From 7ecbea92a5d35aab6cdef4311b5cc523206d0eba Mon Sep 17 00:00:00 2001 From: EthanLavi <89608076+EthanLavi@users.noreply.github.com> Date: Mon, 30 Oct 2023 23:46:50 -0400 Subject: [PATCH 01/36] Trying to integrate network stack, failing to get interrupts --- README.md | 22 +- kernel/Cargo.toml | 1 + kernel/src/interrupts.rs | 71 +++- kernel/src/main.rs | 69 +++- kernel/src/network/README.md | 21 ++ kernel/src/network/TODO.md | 12 + kernel/src/network/arp.rs | 98 ++++++ kernel/src/network/arp_table.rs | 5 + kernel/src/network/bytefield.rs | 304 ++++++++++++++++ kernel/src/network/command_register.rs | 101 ++++++ kernel/src/network/constants.rs | 24 ++ kernel/src/network/devices.rs | 162 +++++++-- kernel/src/network/dhcp.rs | 162 +++++++++ kernel/src/network/e1000.rs | 211 +++++++++++ kernel/src/network/ethernet.rs | 91 +++++ kernel/src/network/icmp.rs | 0 kernel/src/network/init.rs | 63 ++++ kernel/src/network/ip.rs | 170 ++++++--- kernel/src/network/layer.rs | 172 +++++++++ kernel/src/network/mod.rs | 19 +- kernel/src/network/netsync.rs | 62 ++++ kernel/src/network/raw_array.rs | 98 ++++++ kernel/src/network/rtl8139.rs | 464 +++++++++++++++++++++++++ kernel/src/network/socket.rs | 94 +++++ kernel/src/network/tcp.rs | 0 kernel/src/network/udp.rs | 141 ++++++++ src/main.rs | 11 + 27 files changed, 2545 insertions(+), 103 deletions(-) create mode 100644 kernel/src/network/README.md create mode 100644 kernel/src/network/TODO.md create mode 100644 kernel/src/network/arp.rs create mode 100644 kernel/src/network/arp_table.rs create mode 100644 kernel/src/network/bytefield.rs create mode 100644 kernel/src/network/command_register.rs create mode 100644 kernel/src/network/constants.rs create mode 100644 kernel/src/network/dhcp.rs create mode 100644 kernel/src/network/e1000.rs create mode 100644 kernel/src/network/ethernet.rs create mode 100644 kernel/src/network/icmp.rs create mode 100644 kernel/src/network/init.rs create mode 100644 kernel/src/network/layer.rs create mode 100644 kernel/src/network/netsync.rs create mode 100644 kernel/src/network/raw_array.rs create mode 100644 kernel/src/network/rtl8139.rs create mode 100644 kernel/src/network/socket.rs create mode 100644 kernel/src/network/tcp.rs create mode 100644 kernel/src/network/udp.rs diff --git a/README.md b/README.md index 5c011d3..9a4c37e 100644 --- a/README.md +++ b/README.md @@ -2,17 +2,27 @@ ## Build & Run -### Build kernel +### Editing QEMU Configuration -`cargo build` +Can edit src/main.rs to add command line arguments to qemu -### Build kernel and link with bootloader - -`cargo bootimage` +```{rust} +cmd.arg("-FLAG").arg("FLAG_ARG"); +``` ### Boot OS in virtual machine -`cargo run` +`sudo cargo run` + +Adding sudo to enable certain networking requirements. + +### Interacting with the machine over the network + +Can forward packets into the network (just for testing, for example) like + +`nc -4 -u localhost 5555` + +TODO: Make a python script for communicating WASM code with the OS... (and document it) ## Testing diff --git a/kernel/Cargo.toml b/kernel/Cargo.toml index 2421178..ffb5708 100644 --- a/kernel/Cargo.toml +++ b/kernel/Cargo.toml @@ -17,6 +17,7 @@ pc-keyboard = "0.7.0" linked_list_allocator = "0.10.5" pci = "0.0.1" noto-sans-mono-bitmap = "0.2.0" +hashbrown = "0.14.2" [dependencies.wasmi] version = "0.31.0" diff --git a/kernel/src/interrupts.rs b/kernel/src/interrupts.rs index f4384fb..18f7334 100644 --- a/kernel/src/interrupts.rs +++ b/kernel/src/interrupts.rs @@ -1,17 +1,17 @@ -use crate::gdt; -use crate::hlt_loop; -use crate::println; use lazy_static::lazy_static; +use x86_64::structures::idt::{InterruptDescriptorTable, InterruptStackFrame}; +use crate::{gdt, println, print}; use pic8259::ChainedPics; use spin; use x86_64::structures::idt::PageFaultErrorCode; -use x86_64::structures::idt::{InterruptDescriptorTable, InterruptStackFrame}; +use crate::hlt_loop; pub const PIC_1_OFFSET: u8 = 32; pub const PIC_2_OFFSET: u8 = PIC_1_OFFSET + 8; -pub static PICS: spin::Mutex = - spin::Mutex::new(unsafe { ChainedPics::new(PIC_1_OFFSET, PIC_2_OFFSET) }); +pub static PICS: spin::Mutex = spin::Mutex::new({ + unsafe { ChainedPics::new(PIC_1_OFFSET, PIC_2_OFFSET) } +}); #[derive(Debug, Clone, Copy)] #[repr(u8)] @@ -21,7 +21,7 @@ pub enum InterruptIndex { } impl InterruptIndex { - fn as_u8(self) -> u8 { + pub fn as_u8(self) -> u8 { self as u8 } @@ -30,24 +30,67 @@ impl InterruptIndex { } } -lazy_static! { - static ref IDT: InterruptDescriptorTable = { +pub struct InterruptHandler { + idt: InterruptDescriptorTable +} + +pub type InterruptHandlerFunc = extern "x86-interrupt" fn (InterruptStackFrame) -> (); +impl InterruptHandler { + pub fn new() -> Self { let mut idt = InterruptDescriptorTable::new(); idt.breakpoint.set_handler_fn(breakpoint_handler); idt.page_fault.set_handler_fn(page_fault_handler); unsafe { - idt.double_fault - .set_handler_fn(double_fault_handler) + idt.double_fault.set_handler_fn(double_fault_handler) .set_stack_index(gdt::DOUBLE_FAULT_IST_INDEX); } idt[InterruptIndex::Timer.as_usize()].set_handler_fn(timer_interrupt_handler); idt[InterruptIndex::Keyboard.as_usize()].set_handler_fn(keyboard_interrupt_handler); - return idt; + InterruptHandler { idt } + } + + pub fn init(&self) -> (){ + unsafe { self.idt.load_unsafe() }; + } + + // Static function for disabling an irq + pub fn unblock_irq(irq_num: u8) -> () { + let data = unsafe { PICS.lock().read_masks() }; + // set the irq bit to 0 + if irq_num < 8 { + unsafe { PICS.lock().write_masks(data[0] & !(1 << irq_num), data[1]) }; + } else { + unsafe { PICS.lock().write_masks(data[0], data[1] & !(1 << irq_num - 8)) }; + } + } + + // Static function for re-enabling an IRQ + pub fn block_irq(irq_num: u8) -> () { + let data = unsafe { PICS.lock().read_masks() }; + // set the irq bit to 1 + if irq_num < 8 { + unsafe { PICS.lock().write_masks(data[0] | 1 << irq_num, data[1]) }; + } else { + unsafe { PICS.lock().write_masks(data[0], data[1] | 1 << irq_num - 8) }; + } + } + + pub fn register_irq(&mut self, irq_num: usize, handler: InterruptHandlerFunc) -> (){ + println!("Registered Handler @ {}", irq_num + 32); + self.idt[irq_num + 32].set_handler_fn(handler); + unsafe { self.idt.load_unsafe() }; + println!("Registered IRQ @ {}", irq_num); + } +} + +lazy_static! { + pub static ref IDT: spin::Mutex = { + spin::Mutex::new(InterruptHandler::new()) }; } pub fn init_idt() { - IDT.load(); + IDT.lock().init(); } extern "x86-interrupt" fn breakpoint_handler(stack_frame: InterruptStackFrame) { @@ -75,7 +118,7 @@ extern "x86-interrupt" fn double_fault_handler( } extern "x86-interrupt" fn timer_interrupt_handler(_stack_frame: InterruptStackFrame) { - // print!("."); + print!("."); unsafe { PICS.lock() diff --git a/kernel/src/main.rs b/kernel/src/main.rs index 94b8ff8..4a2a957 100644 --- a/kernel/src/main.rs +++ b/kernel/src/main.rs @@ -4,6 +4,7 @@ #![test_runner(kernel::test_runner)] #![reexport_test_harness_main = "test_main"] +use alloc::string::String; use bootloader_api::{ config::{BootloaderConfig, Mapping}, entry_point, BootInfo, @@ -11,7 +12,7 @@ use bootloader_api::{ use core::panic::PanicInfo; use kernel::{ framebuffer, hlt_loop, - network::devices, + network::{ethernet::{EthernetPacket, self}, udp::UDPPacket, ip::{IPPacket, Protocol}, layer::{Layer, HasChecksum, LayerType}, rtl8139::{disable_network_interrupts, NET_INFO, enable_network_interrupts}, socket::RawSocket, init::init_dhcp}, println, task::keyboard, task::{executor::Executor, Task}, @@ -59,20 +60,66 @@ fn kernel_main(boot_info: &'static mut BootInfo) -> ! { kernel::init(); framebuffer::setup(boot_info); - let phys_mem_offset = VirtAddr::new(boot_info.physical_memory_offset.into_option().unwrap()); - let mut mapper = unsafe { memory::init(phys_mem_offset) }; + let phys_mem_offset = boot_info.physical_memory_offset.into_option().unwrap(); + let mut mapper = unsafe { memory::init(VirtAddr::new(phys_mem_offset)) }; let mut frame_allocator = unsafe { BootInfoFrameAllocator::init(&boot_info.memory_regions) }; allocator::init_heap(&mut mapper, &mut frame_allocator).expect("heap init failed"); - println!("Pad..."); - println!("Pad..."); - println!("Pad..."); - println!("Pad..."); - - devices::scan_devices(); - - println!("Hello World!"); + let status_init = { + // empty scope + let mut rtl_driver_lock = NET_INFO.lock(); + let rtl_driver = rtl_driver_lock.get_mut(); + if rtl_driver.is_none() { + panic!("Cannot find network card"); + } + rtl_driver.unwrap().init(&mut frame_allocator, phys_mem_offset) + }; // so that the NET INFO gets released + let status_init_dhcp = init_dhcp(2); + if !status_init && false { + println!("[ERR] Cannot init RTL8139"); + } else if !status_init_dhcp && false { + println!("[ERR] DHCP error -- whats my ip?"); + } else { + let raw_socket = RawSocket::new(5554); + match raw_socket { + Ok(mut socket) => { + loop { + let pkt = socket.get_packet(); + if pkt.get_type() != LayerType::UDP { break; } + let udp_pkt = pkt.unwrap_udp(); + let data_cloned = udp_pkt.data.clone(); + let data_cloned_len = data_cloned.len(); + let user_message = String::from_utf8(udp_pkt.data); + match user_message { + Ok(message) => println!("[USER] {}", message), + Err(err) => println!("[USER-ERR] {:?}", err), + } + + // send back a copy of the packet ("echo") + let ip_layer_res = udp_pkt.ip_packet; + let eth_layer_res = ip_layer_res.ethernet_packet; + let eth_layer = EthernetPacket::gen(eth_layer_res.src_mac.val(), eth_layer_res.dest_mac.val(), ethernet::EthType::IPv4); + let udp_size = UDPPacket::packet_size() + data_cloned_len as u16; + let ip_layer = IPPacket::gen(eth_layer, udp_size, Protocol::UDP, ip_layer_res.destination_ip.val(), ip_layer_res.source_ip.val()); + let mut udp_layer = UDPPacket::gen(ip_layer, udp_pkt.dest_port.val(), udp_pkt.src_port.val(), data_cloned_len as u16); + udp_layer.data = data_cloned; + let data_2_send = udp_layer.serialize(); + let start_udp = data_2_send.len() - (UDPPacket::packet_size() as usize + data_cloned_len); + let start_ip = start_udp - (IPPacket::packet_size() as usize); + udp_layer.ip_packet.calculate_checksum(&data_2_send[start_ip..start_udp]); + udp_layer.calculate_checksum(&data_2_send[start_udp..]); + let data_2_send_final = udp_layer.serialize(); + disable_network_interrupts(); + NET_INFO.lock().get_ref().unwrap().send_packet(&data_2_send_final); + enable_network_interrupts(); + } + println!("[INFO] Socket is closing"); + socket.close(); + }, + Err(err) => println!("{:?}", err), + } + } #[cfg(test)] test_main(); diff --git a/kernel/src/network/README.md b/kernel/src/network/README.md new file mode 100644 index 0000000..ab4ddc2 --- /dev/null +++ b/kernel/src/network/README.md @@ -0,0 +1,21 @@ +# Documentation of the network stack + +TODO + +## TODOS + +[x] PCI scanning for devices +[x] RTL8139 Driver Code +[x] Ethernet, IP, UDP, ARP, DHCP +[x] RawSocket API +[] Better Socket API +[] Refactor so that all of networking is tested +[] Refactor to include more documentation on the network module +[] Refactor to verify checksums +[] Verify other parts of the packet +[] Fix synchronization to be much cleaner +[] Clean up ugly stuff +[] DHCP parse additional options +[] TCP +[] Refactor to be all constants +[] search for todo and fix thoses diff --git a/kernel/src/network/TODO.md b/kernel/src/network/TODO.md new file mode 100644 index 0000000..064f2f1 --- /dev/null +++ b/kernel/src/network/TODO.md @@ -0,0 +1,12 @@ +# TODOs + +* Refactor so that all of networking is tested +* Refactor to include more documentation on the network module +* Refactor to verify checksums +* Verify other parts of the packet +* Fix synchronization to be much cleaner +* Clean up ugly stuff +* DHCP parse additional options +* TCP +* Refactor to be all constants +* search for todo and fix thoses diff --git a/kernel/src/network/arp.rs b/kernel/src/network/arp.rs new file mode 100644 index 0000000..c1cfab5 --- /dev/null +++ b/kernel/src/network/arp.rs @@ -0,0 +1,98 @@ +use super::{ + bytefield::{Bytefield16, Bytefield32, Bytefield48, Bytefield8}, + ethernet::{EthernetPacket, EthType}, + layer::{Layer, LayerType}, +}; +use alloc::vec; +use alloc::vec::Vec; +#[derive(Debug)] +pub struct ArpPacket { + pub ethernet_packet: EthernetPacket, + hardware_type: Bytefield16, + protocol_type: Bytefield16, + hardware_address_length: u8, + protocol_address_length: u8, + operation: Bytefield16, + pub sender_mac: Bytefield48, + pub sender_ip: Bytefield32, + pub recp_mac: Bytefield48, + pub recp_ip: Bytefield32, +} + +impl ArpPacket { + // Create an empty packet with all 0s + pub fn new() -> Self { + ArpPacket { + ethernet_packet: EthernetPacket::new(), + hardware_type: Bytefield16::new(0), + protocol_type: Bytefield16::new(0), + hardware_address_length: 0, + protocol_address_length: 0, + operation: Bytefield16::new(0), + sender_mac: Bytefield48::new(0), + sender_ip: Bytefield32::new(0), + recp_mac: Bytefield48::new(0), + recp_ip: Bytefield32::new(0), + } + } + + pub fn gen(eth_layer: EthernetPacket, source_ip: u32, recp_ip: u32, is_req: bool) -> Self { + let recp_mac = eth_layer.dest_mac; + let sender_mac = eth_layer.src_mac; + assert!(eth_layer.packet_type == EthType::Arp); + ArpPacket { + ethernet_packet: eth_layer, + hardware_type: Bytefield16::new(0x1), // ethernet + protocol_type: Bytefield16::new(0x0800), // ipv4 + hardware_address_length: 6, // ethernet is the value 6 + protocol_address_length: 4, // ethernet is the value 4 + operation: Bytefield16::new(if is_req { 1 } else { 2 }), // Request is 1, Reply is 2 + sender_mac, + sender_ip: Bytefield32::new(source_ip), // what is my ip + recp_mac, + recp_ip: Bytefield32::new(if is_req { 0 } else { recp_ip }), + } + } +} + +impl Layer for ArpPacket { + type Input = EthernetPacket; + fn parse(eth_layer: EthernetPacket, bytevec: &[u8]) -> (Self, usize, LayerType) { + let mut packet = ArpPacket::new(); // create an empty packet + + // Read ethernet packet and 28 bytes + let mut i = 0; + packet.ethernet_packet = eth_layer; + packet.hardware_type = Bytefield16::read_inc(&bytevec[i..], &mut i); + packet.protocol_type = Bytefield16::read_inc(&bytevec[i..], &mut i); + packet.hardware_address_length = Bytefield8::read_inc(&bytevec[i..], &mut i).data; + packet.protocol_address_length = Bytefield8::read_inc(&bytevec[i..], &mut i).data; + packet.operation = Bytefield16::read_inc(&bytevec[i..], &mut i); + packet.sender_mac = Bytefield48::read_inc(&bytevec[i..], &mut i); + packet.sender_ip = Bytefield32::read_inc(&bytevec[i..], &mut i); + packet.recp_mac = Bytefield48::read_inc(&bytevec[i..], &mut i); + packet.recp_ip = Bytefield32::read_inc(&bytevec[i..], &mut i); + + assert!(i == 28); // Arp packet should be 28 bytes + return (packet, i, LayerType::UNDEF); + } + + fn serialize(&self) -> Vec { + let mut res = vec![]; + res.extend(self.ethernet_packet.serialize()); + res.extend(self.hardware_type.data); + res.extend(self.protocol_type.data); + res.push(self.hardware_address_length); + res.push(self.protocol_address_length); + res.extend(self.operation.data); + res.extend(self.sender_mac.data); + res.extend(self.sender_ip.data); + res.extend(self.recp_mac.data); + res.extend(self.recp_ip.data); + res + } + + fn packet_size() -> u16 { + 28 + } +} diff --git a/kernel/src/network/arp_table.rs b/kernel/src/network/arp_table.rs new file mode 100644 index 0000000..e4cd1fd --- /dev/null +++ b/kernel/src/network/arp_table.rs @@ -0,0 +1,5 @@ +pub struct ArpEntry { + pub mac: u64, + pub ip: u32, + pub expires: u16, +} \ No newline at end of file diff --git a/kernel/src/network/bytefield.rs b/kernel/src/network/bytefield.rs new file mode 100644 index 0000000..f2581e7 --- /dev/null +++ b/kernel/src/network/bytefield.rs @@ -0,0 +1,304 @@ +use core::ops::{Index, IndexMut}; + +// N.B.: BytefieldS STORE IN BIG ENDIAN (as per network requirements) + +#[derive(Debug, Clone, Copy)] +pub struct Bytefield128 { + pub data: [u8; 16], +} + +impl Bytefield128 { + pub fn new(val: u128) -> Self { + let mut data = [0; Self::size()]; + for i in 0..Self::size() { + data[i] = (val >> ((Self::size() - 1 - i) * 8) & 0xFF) as u8; + } + Bytefield128 { data } + } + + pub fn read_inc(bytevec: &[u8], i: &mut usize) -> Self { + *i += Self::size(); + let mut data = [0; Self::size()]; + for i in 0..Self::size() { + data[i] = bytevec[Self::size() - 1 - i]; + } + Bytefield128 { data } + } + + // Get the value in swapped endianness (for example, if parsing a web thing, you'll get the little endian version) + pub fn val(&self) -> u128 { + let mut res = 0_u128; + for i in 0..Self::size() { + res = res | ((self.data[i] as u128) << (i * 8)); + } + return res; + } + + // Swap the endianness of the data + pub fn swap_endianness(&mut self) -> () { + self.data.reverse(); + } + + // Get the number of bytes + pub const fn size() -> usize { + 16 + } +} + +#[derive(Clone, Copy)] +pub struct Bytefield64 { + pub data: [u8; 8], +} + +impl Bytefield64 { + pub fn new(val: u64) -> Self { + let mut data = [0; Self::size()]; + for i in 0..Self::size() { + data[i] = (val >> ((Self::size() - 1 - i) * 8) & 0xFF) as u8; + } + Self { data } + } + + pub fn read_inc(bytevec: &[u8], i: &mut usize) -> Self { + *i += Self::size(); + let mut data = [0; Self::size()]; + for i in 0..Self::size() { + data[i] = bytevec[Self::size() - 1 - i]; + } + Self { data } + } + + // Get the value in swapped endianness (for example, if parsing a web thing, you'll get the little endian version) + pub fn val(&self) -> u64 { + let mut res = 0_u64; + for i in 0..Self::size() { + res = res | ((self.data[i] as u64) << (i * 8)); + } + return res; + } + + // Swap the endianness of the data + pub fn swap_endianness(&mut self) -> () { + self.data.reverse(); + } + + // Get the number of bytes + pub const fn size() -> usize { + 8 + } +} + +#[derive(Debug, Clone, Copy)] +pub struct Bytefield48 { + pub data: [u8; 6], +} + +impl Bytefield48 { + pub fn new(val: u64) -> Self { + let mut data = [0; Self::size()]; + for i in 0..Self::size() { + data[i] = (val >> ((Self::size() - 1 - i) * 8) & 0xFF) as u8; + } + Self { data } + } + + pub fn read_inc(bytevec: &[u8], i: &mut usize) -> Self { + *i += Self::size(); + let mut data = [0; Self::size()]; + for i in 0..Self::size() { + data[i] = bytevec[Self::size() - 1 - i]; + } + Self { data } + } + + // Get the value in swapped endianness + pub fn val(&self) -> u64 { + let mut res = 0_u64; + for i in 0..Self::size() { + res = res | ((self.data[i] as u64) << (i * 8)); + } + return res; + } + + // Swap the endianness of the data + pub fn swap_endianness(&mut self) -> () { + self.data.reverse(); + } + + // Get the number of bytes + pub const fn size() -> usize { + 6 + } +} + +#[derive(Debug, Clone, Copy)] +pub struct Bytefield32 { + pub data: [u8; 4], +} + +impl Bytefield32 { + pub fn new(val: u32) -> Self { + let mut data = [0; Self::size()]; + for i in 0..Self::size() { + data[i] = (val >> ((Self::size() - 1 - i) * 8) & 0xFF) as u8; + } + Self { data } + } + + pub fn read_inc(bytevec: &[u8], i: &mut usize) -> Self { + *i += Self::size(); + let mut data = [0; Self::size()]; + for i in 0..Self::size() { + data[i] = bytevec[Self::size() - 1 - i]; + } + Self { data } + } + + // Get the value in swapped endianness (for example, if parsing a web thing, you'll get the little endian version) + pub fn val(&self) -> u32 { + let mut res = 0_u32; + for i in 0..Self::size() { + res = res | ((self.data[i] as u32) << (i * 8)); + } + return res; + } + + // Swap the endianness of the data + pub fn swap_endianness(&mut self) -> () { + self.data.reverse(); + } + + // Get the number of bytes + pub const fn size() -> usize { + 4 + } +} + +#[derive(Debug, Clone, Copy)] +pub struct Bytefield16 { + pub data: [u8; 2], +} + +impl Bytefield16 { + // Create a bytefield and swap endian-ness + pub fn new(val: u16) -> Self { + Self { data: [(val >> 1 * 8 & 0xFF) as u8, (val >> 0 * 8 & 0xFF) as u8] } + } + + pub fn read_inc(bytevec: &[u8], i: &mut usize) -> Self { + *i += Self::size(); + Self { data: [bytevec[1], bytevec[0]] } + } + + /// Get the original value used to create the bytefield (will preserve the endian-ness or the parsed data) + pub fn val(&self) -> u16 { + let mut res = 0_u16; + for i in 0..Self::size() { + res = res | ((self.data[i] as u16) << (i * 8)); + } + return res; + } + + // Swap the endianness of the data + pub fn swap_endianness(&mut self) -> () { + self.data.reverse(); + } + + pub fn size() -> usize { + 2 + } +} + +#[derive(Clone, Copy)] +pub struct Bytefield8 { + pub data: u8, +} + +impl Bytefield8 { + pub fn new(data: u8) -> Self { + Bytefield8 { data } + } + + pub fn read(bytevec: &[u8]) -> Self { + Bytefield8 { data: bytevec[0] } + } + + pub fn read_inc(bytevec: &[u8], i: &mut usize) -> Self { + *i += Self::size(); + Self::read(bytevec) + } + + pub fn val(&self) -> u8 { + self.data + } + + pub fn size() -> usize { + 1 + } +} + +// ===== DEFINING INDEXING OPERATIONS ====== // +impl Index for Bytefield128 { + type Output = u8; + fn index<'a>(&'a self, i: usize) -> &'a u8 { + &self.data[i] + } +} + +impl IndexMut for Bytefield128 { + fn index_mut<'a>(&'a mut self, i: usize) -> &'a mut u8 { + &mut self.data[i] + } +} + +impl Index for Bytefield64 { + type Output = u8; + fn index<'a>(&'a self, i: usize) -> &'a u8 { + &self.data[i] + } +} + +impl IndexMut for Bytefield64 { + fn index_mut<'a>(&'a mut self, i: usize) -> &'a mut u8 { + &mut self.data[i] + } +} + +impl Index for Bytefield48 { + type Output = u8; + fn index<'a>(&'a self, i: usize) -> &'a u8 { + &self.data[i] + } +} + +impl IndexMut for Bytefield48 { + fn index_mut<'a>(&'a mut self, i: usize) -> &'a mut u8 { + &mut self.data[i] + } +} + +impl Index for Bytefield32 { + type Output = u8; + fn index<'a>(&'a self, i: usize) -> &'a u8 { + &self.data[i] + } +} + +impl IndexMut for Bytefield32 { + fn index_mut<'a>(&'a mut self, i: usize) -> &'a mut u8 { + &mut self.data[i] + } +} + +impl Index for Bytefield16 { + type Output = u8; + fn index<'a>(&'a self, i: usize) -> &'a u8 { + &self.data[i] + } +} + +impl IndexMut for Bytefield16 { + fn index_mut<'a>(&'a mut self, i: usize) -> &'a mut u8 { + &mut self.data[i] + } +} \ No newline at end of file diff --git a/kernel/src/network/command_register.rs b/kernel/src/network/command_register.rs new file mode 100644 index 0000000..acc1249 --- /dev/null +++ b/kernel/src/network/command_register.rs @@ -0,0 +1,101 @@ +#[derive(Debug, Clone)] +pub struct CommandRegister { + cr: u16, +} + +impl CommandRegister { + pub fn new(cr: u16) -> Self { + CommandRegister { cr } + } + + // basic getter for internal data + pub fn data(&self) -> u16 { + return self.cr; + } + + // 0th bit + pub fn get_io_space_bit(&self) -> bool { + return (self.cr & 0x1) != 0; + } + pub fn set_io_space_bit(&mut self, is_on: bool) -> (){ + match is_on { + true => self.cr = self.cr | 0x1, + false => self.cr = self.cr & !0x1, + } + } + + // 1st bit + pub fn get_memory_space_bit(&self) -> bool { + return (self.cr & 0x2) != 0; + } + pub fn set_memory_space_bit(&mut self, is_on: bool) -> (){ + match is_on { + true => self.cr = self.cr | 0x2, + false => self.cr = self.cr & !0x2, + } + } + + // 2nd bit + pub fn get_bus_master_bit(&self) -> bool { + return (self.cr & 0x4) != 0; + } + pub fn set_bus_master_bit(&mut self, is_on: bool) -> (){ + match is_on { + true => self.cr = self.cr | 0x4, + false => self.cr = self.cr & !0x4, + } + } + + // 3rd bit + pub fn get_special_cycles_bit(&self) -> bool { + return (self.cr & 0x8) != 0; + } + + // 4th bit + pub fn get_memory_write_invalidate_enable_bit(&self) -> bool { + return (self.cr & 0x10) != 0; + } + + // 5th bit + pub fn get_vga_palette_snoop_bit(&self) -> bool { + return (self.cr & 0x20) != 0; + } + + // 6th bit + pub fn get_parity_err_res_bit(&self) -> bool { + return (self.cr & 0x40) != 0; + } + pub fn set_parity_err_res_bit(&mut self, is_on: bool) -> (){ + match is_on { + true => self.cr = self.cr | 0x40, + false => self.cr = self.cr & !0x40, + } + } + + // 8th bit + pub fn get_serr_enable_bit(&self) -> bool { + return (self.cr & 0x100) != 0; + } + pub fn set_serr_enable_bit(&mut self, is_on: bool) -> (){ + match is_on { + true => self.cr = self.cr | 0x100, + false => self.cr = self.cr & !0x100, + } + } + + // 9th bit + pub fn get_fast_back_to_back_enable_bit(&self) -> bool { + return (self.cr & 0x200) != 0; + } + + // 10th bit + pub fn get_interrupt_disable_bit(&self) -> bool { + return (self.cr & 0x400) != 0; + } + pub fn set_interrupt_disable_bit(&mut self, is_on: bool) -> (){ + match is_on { + true => self.cr = self.cr | 0x400, + false => self.cr = self.cr & !0x400, + } + } +} \ No newline at end of file diff --git a/kernel/src/network/constants.rs b/kernel/src/network/constants.rs new file mode 100644 index 0000000..6148ca4 --- /dev/null +++ b/kernel/src/network/constants.rs @@ -0,0 +1,24 @@ +// Broadcast constants +pub const BROADCAST_ADDR: u32 = 0xFFFFFFFF; +pub const BROADCAST_MAC: u64 = 0xFFFFFFFFFFFF; + +// Common port numbers +pub const DHCP_CLIENT_PORT: u16 = 68; +pub const DHCP_SERVER_PORT: u16 = 67; + +// RTL device specific info +pub const RTL_VEND: u32 = 0x10EC; // Vendor ID for the rtl8139 +pub const RTL_DEV: u32 = 0x8139; // Device ID for the rtl8139 +pub const TOK: u16 = 1 << 2; // TOK bit +pub const ROK: u16 = 1 << 0; // ROK bit +pub const TRANSMIT_REG: [u32; 4] = [0x20, 0x24, 0x28, 0x2C]; +pub const TRANSMIT_CMD: [u32; 4] = [0x10, 0x14, 0x18, 0x1C]; +pub const INTERRUPT_MASK: u16 = 0x01 | 0x04 | 0x10 | 0x08 | 0x02; +pub const RX_BUFFER_SIZE: u16 = 8192; // how big the buffer is +pub const CR_RST: u16 = 0x10; // Reset, set to 1 to invoke S/W reset, held to 1 while resetting +pub const CR_RE: u8 = 0x08; // Reciever Enable, enables receiving +pub const CR_TE: u8 = 0x04; // Transmitter Enable, enables transmitting +pub const CR_BUFE: u8 = 0x01; // Rx buffer is empty +pub const CR: u32 = 0x37; // command register (1byte) +pub const CAPR: u32 = 0x38; // current address of packet read (2byte, C mode, initial value 0xFFF0) +pub const RX_READ_PTR_MASK: u16 = !0x3; // \ No newline at end of file diff --git a/kernel/src/network/devices.rs b/kernel/src/network/devices.rs index 98e05b2..d5a61b6 100644 --- a/kernel/src/network/devices.rs +++ b/kernel/src/network/devices.rs @@ -1,14 +1,24 @@ use alloc::vec::Vec; use x86_64::instructions::port::Port; -use crate::println; +use super::command_register::CommandRegister; const CONFIG_ADDRESS: u16 = 0xCF8; const CONFIG_DATA: u16 = 0xCFC; -pub struct Device {} +#[derive(Clone)] +pub struct Device { + pub bus: u8, + pub slot: u8, + pub vendor_id: u16, + pub device_id: u16, + pub class_code: PCIClassCodes, + pub sub_class: u8, + pub io_base: Option, + pub irq: Option, +} -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Clone)] pub enum PCIClassCodes { Unclassified, MassStorageController, @@ -32,7 +42,7 @@ pub enum PCIClassCodes { NonEssentialInstrumentation, CoProcessor, Reserved, - Unassigned, + Unassigned } impl PCIClassCodes { @@ -60,76 +70,164 @@ impl PCIClassCodes { 0x13 => Self::NonEssentialInstrumentation, 0x40 => Self::CoProcessor, 0xFF => Self::Unassigned, - _ => Self::Reserved, + _ => Self::Reserved } } } -// Query PCI for 16 bits of data of a device defined by (bus, slot) -// Will offset into the PCI Configuration Header by offset -// Func is used for multi-function devices... (don't think that will be used with network card) -pub fn pci_config_read_dword(bus: u8, slot: u8, func: u8, offset: u8) -> u32 { +// Write into the config address +fn create_confg_address(bus: u8, slot: u8, func: u8, offset: u8){ let lbus = bus as u32; let lslot = slot as u32; let lfunc = func as u32; - let address = - (lbus << 16) | (lslot << 11) | (lfunc << 8) | ((offset as u32) & 0xFC) | 0x80000000u32; + let address = ((lbus << 16) | (lslot << 11) | (lfunc << 8) | ((offset as u32) & 0xFC) | (0x80000000 as u32)) as u32; let mut port = Port::::new(CONFIG_ADDRESS); // Write the address unsafe { port.write(address) }; +} + +// Query PCI for 32 bits of data of a device defined by (bus, slot) +// Will offset into the PCI Configuration Header by offset +// Func is used for multi-function devices... (don't think that will be used with network card) +fn pci_config_read_dword(bus: u8, slot: u8, func: u8, offset: u8) -> u32 { + create_confg_address(bus, slot, func, offset); + let mut port = Port::::new(CONFIG_DATA); + // Read the data + let data: u32 = unsafe { port.read() }; + return data; +} +/// Query PCI to write 16 bits of data of a device defined by (bus, slot) +/// Will offset into the PCI configuration header by offset +// Func is used for multi-function devices... (don't think that will be used with network card) +fn pci_config_write_word(bus: u8, slot: u8, func: u8, offset: u8, word: u16) -> () { + create_confg_address(bus, slot, func, offset); let mut port = Port::::new(CONFIG_DATA); // Read the data let data: u32 = unsafe { port.read() }; - data + let new_data = (data & 0xFFFF0000) | word as u32; + unsafe { port.write(new_data); } } /// Check if a device exists at (bus, slot) /// If it does, it will return the vendor ID -pub fn pci_check_vendor(bus: u8, slot: u8) -> Option { +fn pci_check_vendor(bus: u8, slot: u8) -> Option { /* Try and read the first configuration register. Since there are no * vendors that == 0xFFFF, it must be a non-existent device. */ let vendor = (pci_config_read_dword(bus, slot, 0, 0) & 0xFFFF) as u16; if vendor != 0xFFFF { - println!("{}", vendor); - Some(vendor) - } else { - None + return Some(vendor); } + return None; } // Assumes a device at (bus, slot) // Will extract the class code from the configuration space -pub fn pci_get_class_code(bus: u8, slot: u8) -> PCIClassCodes { +fn pci_get_device_id(bus: u8, slot: u8) -> u16 { + let device_id = (pci_config_read_dword(bus, slot, 0, 0) >> 16) & 0xFFFF; + return device_id as u16; +} + +// Assumes a device at (bus, slot) +// Will extract the class code (class and subclass) from the configuration space +fn pci_get_class_code(bus: u8, slot: u8) -> (PCIClassCodes, u8) { let code = pci_config_read_dword(bus, slot, 0, 0x8); - let _revision_id = (code & 0xFF) as u8; - let _progif = ((code >> 8) & 0xFF) as u8; - let _subclass = ((code >> 16) & 0xFF) as u8; + // let _revision_id = (code & 0xFF) as u8; + // let _progif = ((code >> 8) & 0xFF) as u8; + let subclass = ((code >> 16) & 0xFF) as u8; let class = ((code >> 24) & 0xFF) as u8; - PCIClassCodes::from(class) + return (PCIClassCodes::from(class), subclass); +} + +// Read the interrupt line of the PCI Configuration address space +// "For the x86 architecture this register corresponds to the PIC IRQ numbers 0-15" and a value of 0xFF defines no connection +fn pci_get_irq(bus: u8, slot: u8) -> Option { + let irq = (pci_config_read_dword(bus, slot, 0, 0x3C) & 0xFF) as u8; + if irq == 0xFF { + None + } else { + Some(irq) + } +} + +fn pci_get_cmd_reg(bus: u8, slot: u8) -> CommandRegister { + let cr = pci_config_read_dword(bus, slot, 0, 0x4); + return CommandRegister::new((cr & 0xFFFF) as u16); +} + +// Set the command register +fn pci_set_cmd_reg(bus: u8, slot: u8, cr: CommandRegister) -> (){ + // Write the register with cr + pci_config_write_word(bus, slot, 0, 0x4, cr.data()); +} + +// Get io base of header type 0x0 +fn pci_get_io_base(bus: u8, slot: u8) -> Option { + let mut cr = pci_get_cmd_reg(bus, slot); + let org = cr.clone(); + // Must disable IO and memory space decode before accessing bar + cr.set_memory_space_bit(false); + cr.set_io_space_bit(false); + for i in 0..5 { + let bar = pci_config_read_dword(bus, slot, 0, 0x10 + (i * 0x4)); + if (bar & 0x1) == 0x1 { + // We have an IO address + // making sure to mask lower bits + return Some(bar & 0xFFFFFFFC); + } + } + // Restore original configuration + pci_set_cmd_reg(bus, slot, org); + return None +} + +impl Device { + // Read the command register of the device + pub fn read_command_register(&self) -> CommandRegister { + return pci_get_cmd_reg(self.bus, self.slot); + } + + // Write to the command register of the device + pub fn write_command_register(&self, new_val: CommandRegister) -> () { + pci_set_cmd_reg(self.bus, self.slot, new_val); + } } // Will return X devices // TODO: multiprocessing safety? -pub fn scan_devices() -> [Option; 3] { +pub fn scan_devices() -> Vec { let mut device_bus_slots: Vec<(u8, u8)> = Vec::new(); for bus in 0..255 { for slot in 0..31 { match pci_check_vendor(bus, slot) { - Some(_) => { - device_bus_slots.push((bus, slot)); - } - None => continue, + Some(_) => { device_bus_slots.push((bus, slot)); } + None => continue } } } - + + let mut results: Vec = Vec::new(); for bus_slot in device_bus_slots.iter() { - println!("{:?}", pci_get_class_code(bus_slot.0, bus_slot.1)); + let bus = bus_slot.0; + let slot = bus_slot.1; + let vendor_id = pci_check_vendor(bus, slot).unwrap(); + let device_id = pci_get_device_id(bus, slot); + let class_subclass = pci_get_class_code(bus, slot); + let irq = pci_get_irq(bus, slot); + let io_base = pci_get_io_base(bus, slot); + results.push(Device { + bus, + slot, + vendor_id, + device_id, + class_code: class_subclass.0, + sub_class: class_subclass.1, + irq, + io_base + }); } - let results = [None, None, None]; - results -} + return results; +} \ No newline at end of file diff --git a/kernel/src/network/dhcp.rs b/kernel/src/network/dhcp.rs new file mode 100644 index 0000000..2e1e903 --- /dev/null +++ b/kernel/src/network/dhcp.rs @@ -0,0 +1,162 @@ +use alloc::vec; + +use super::{ + bytefield::{Bytefield128, Bytefield16, Bytefield32, Bytefield8, Bytefield48}, + layer::{Layer, HasChecksum, LayerType}, + udp::UDPPacket, ip::IPPacket, +}; + +struct WrappedU32 { + data: u32, +} + +impl WrappedU32 { + pub fn get(&self) -> u32 { + self.data + } + pub fn set(&mut self, data: u32) { + self.data = data; + } +} + +static mut ID_GEN: spin::Mutex = spin::Mutex::new(WrappedU32 { data: 0 }); +#[derive(Debug)] +pub struct DHCPPacket { + pub udp_packet: UDPPacket, // public for checksumming + op_code: u8, + hardware_type: u8, + hardware_address_length: u8, + hops: u8, + transaction_identifier: Bytefield32, + seconds: Bytefield16, + flags: Bytefield16, + pub client_ip: Bytefield32, + pub my_ip: Bytefield32, + pub server_ip: Bytefield32, + pub gateway_ip: Bytefield32, + client_hardware_address: Bytefield128, + sname: [u8; 64], + file: [u8; 128], + options: [u8; 64], // variable length +} + +impl DHCPPacket { + pub fn new() -> Self { + DHCPPacket { + udp_packet: UDPPacket::new(), + op_code: 0, // 1 byte + hardware_type: 0, // 1 byte + hardware_address_length: 0, // 1 byte + hops: 0, // 1 byte + transaction_identifier: Bytefield32::new(0), // 4 bytes + seconds: Bytefield16::new(0), // 2 bytes + flags: Bytefield16::new(0), // 2 bytes + client_ip: Bytefield32::new(0), // 4 bytes + my_ip: Bytefield32::new(0), // 4 bytes + server_ip: Bytefield32::new(0), // 4 bytes + gateway_ip: Bytefield32::new(0), // 4 bytes + client_hardware_address: Bytefield128::new(0), // 16 bytes + sname: [0; 64], // 64 bytes + file: [0; 128], // 128 bytes + options: [0; 64], // 64 bytes (can be more, todo!) + // 300 bytes total + } + } + + pub fn gen(udp_packet: UDPPacket, ip_address: Option, mac_address: u64) -> Self { + let identification = unsafe { + let mut id_gen = ID_GEN.lock(); + let id_gen_old = id_gen.get(); + id_gen.set((id_gen_old + 1) % 0xFFFF); + Bytefield32::new(id_gen.get()) + }; + let mac = Bytefield48::new(mac_address); + let mut client_hardware_address = Bytefield128::new(0); + for i in 0..6 { + client_hardware_address[i] = mac[i]; + } + let mut dhcp = DHCPPacket { + udp_packet, + op_code: 1, // 1 for is request + hardware_type: 1, // ethernet is 1 + hardware_address_length: 6, // corresponds with hardware address length + hops: 0, // must be 0 from client + transaction_identifier: identification, // random ID + seconds: Bytefield16::new(0), // doesn't matter + flags: Bytefield16::new(if ip_address.is_none() { 0x1 } else { 0x0 }), // 1 indicates that I don't know my IP address + client_ip: Bytefield32::new(ip_address.unwrap_or(0)), // 0 since + my_ip: Bytefield32::new(0), // 4 bytes + server_ip: Bytefield32::new(0), // 4 bytes + gateway_ip: Bytefield32::new(0), // 4 bytes + client_hardware_address, // 16 bytes + sname: [0; 64], // 64 bytes + file: [0; 128], // 128 bytes + options: [0; 64], // todo: 64 bytes (can be more) + // 300 bytes total + }; + let data = dhcp.serialize(); + let start_udp = data.len() - (DHCPPacket::packet_size() as usize + UDPPacket::packet_size() as usize); + let start_ip = start_udp - (IPPacket::packet_size() as usize); + dhcp.udp_packet.ip_packet.calculate_checksum(&data[start_ip..start_udp]); + dhcp.udp_packet.calculate_checksum(&data[start_udp..]); + return dhcp; + } +} + +impl Layer for DHCPPacket { + type Input = UDPPacket; + fn parse(udp_layer: UDPPacket, bytevec: &[u8]) -> (Self, usize, LayerType) + where + Self: Sized, + { + let mut packet = DHCPPacket::new(); // create an empty packet + let mut i = 0; + packet.udp_packet = udp_layer; + packet.op_code = Bytefield8::read_inc(&bytevec[i..], &mut i).data; + packet.hardware_type = Bytefield8::read_inc(&bytevec[i..], &mut i).data; + packet.hardware_address_length = Bytefield8::read_inc(&bytevec[i..], &mut i).data; + packet.hops = Bytefield8::read_inc(&bytevec[i..], &mut i).data; + packet.transaction_identifier = Bytefield32::read_inc(&bytevec[i..], &mut i); + packet.seconds = Bytefield16::read_inc(&bytevec[i..], &mut i); + packet.flags = Bytefield16::read_inc(&bytevec[i..], &mut i); + packet.client_ip = Bytefield32::read_inc(&bytevec[i..], &mut i); + packet.my_ip = Bytefield32::read_inc(&bytevec[i..], &mut i); + packet.server_ip = Bytefield32::read_inc(&bytevec[i..], &mut i); + packet.gateway_ip = Bytefield32::read_inc(&bytevec[i..], &mut i); + packet.client_hardware_address = Bytefield128::read_inc(&bytevec[i..], &mut i); + i += 64; // sname + i += 128; // file + i += 64; // options + // Ignoring those parts of the DHCP packet. + let left_to_parse = packet.udp_packet.length.val() - 308; + i += left_to_parse as usize; + assert!(i >= 300); // 300 bytes + return (packet, i, LayerType::UNDEF); + } + + fn serialize(&self) -> alloc::vec::Vec { + let mut res = vec![]; + res.extend(self.udp_packet.serialize()); + res.push(self.op_code); + res.push(self.hardware_type); + res.push(self.hardware_address_length); + res.push(self.hops); + res.extend(self.transaction_identifier.data); + res.extend(self.seconds.data); + res.extend(self.flags.data); + res.extend(self.client_ip.data); + res.extend(self.my_ip.data); + res.extend(self.server_ip.data); + res.extend(self.gateway_ip.data); + res.extend(self.client_hardware_address.data); + res.extend(self.sname); + res.extend(self.file); + res.extend(self.options); + assert!(res.len() == (300 + self.udp_packet.serialize().len())); + res + } + + fn packet_size() -> u16 { + 300 + } +} \ No newline at end of file diff --git a/kernel/src/network/e1000.rs b/kernel/src/network/e1000.rs new file mode 100644 index 0000000..08528d1 --- /dev/null +++ b/kernel/src/network/e1000.rs @@ -0,0 +1,211 @@ +use alloc::boxed::Box; + +use super::devices::Device; + +const INTEL_VEND: u32 = 0x8086; // Vendor ID for Intel +const E1000_DEV: u32 = 0x100E; // Device ID for the e1000 Qemu, Bochs, and VirtualBox emmulated NICs +const E1000_I217: u32 = 0x153A; // Device ID for Intel I217 +const E1000_82577LM: u32 = 0x10EA; // Device ID for Intel 82577LM + +// Constants From https://wiki.osdev.org/Intel_Ethernet_i217 +const REG_CTRL: u32 = 0x0000; +const REG_STATUS: u32 = 0x0008; +const REG_EEPROM: u32 = 0x0014; +const REG_CTRL_EXT: u32 = 0x0018; +const REG_IMASK: u32 = 0x00D0; +const REG_RCTRL: u32 = 0x0100; +const REG_RXDESCLO: u32 = 0x2800; +const REG_RXDESCHI: u32 = 0x2804; +const REG_RXDESCLEN: u32 = 0x2808; +const REG_RXDESCHEAD: u32 = 0x2810; +const REG_RXDESCTAIL: u32 = 0x2818; + +const REG_TCTRL: u32 = 0x0400; +const REG_TXDESCLO: u32 = 0x3800; +const REG_TXDESCHI: u32 = 0x3804; +const REG_TXDESCLEN: u32 = 0x3808; +const REG_TXDESCHEAD: u32 = 0x3810; +const REG_TXDESCTAIL: u32 = 0x3818; + +const REG_RDTR: u32 = 0x2820; // RX Delay Timer Register +const REG_RXDCTL: u32 = 0x2828; // RX Descriptor Control +const REG_RADV: u32 = 0x282C; // RX Int. Absolute Delay Timer +const REG_RSRPD: u32 = 0x2C00; // RX Small Packet Detect Interrupt + +const REG_TIPG: u32 = 0x0410; // Transmit Inter Packet Gap +const ECTRL_SLU: u32 = 0x40; //set link up + +const RCTL_EN: u32 = 1 << 1; // Receiver Enable +const RCTL_SBP: u32 = 1 << 2; // Store Bad Packets +const RCTL_UPE: u32 = 1 << 3; // Unicast Promiscuous Enabled +const RCTL_MPE: u32 = 1 << 4; // Multicast Promiscuous Enabled +const RCTL_LPE: u32 = 1 << 5; // Long Packet Reception Enable +const RCTL_LBM_NONE: u32 = 0 << 6; // No Loopback +const RCTL_LBM_PHY: u32 = 3 << 6; // PHY or external SerDesc loopback +const RTCL_RDMTS_HALF: u32 = 0 << 8; // Free Buffer Threshold is 1/2 of RDLEN +const RTCL_RDMTS_QUARTER: u32 = 1 << 8; // Free Buffer Threshold is 1/4 of RDLEN +const RTCL_RDMTS_EIGHTH: u32 = 2 << 8; // Free Buffer Threshold is 1/8 of RDLEN +const RCTL_MO_36: u32 = 0 << 12; // Multicast Offset - bits 47:36 +const RCTL_MO_35: u32 = 1 << 12; // Multicast Offset - bits 46:35 +const RCTL_MO_34: u32 = 2 << 12; // Multicast Offset - bits 45:34 +const RCTL_MO_32: u32 = 3 << 12; // Multicast Offset - bits 43:32 +const RCTL_BAM: u32 = 1 << 15; // Broadcast Accept Mode +const RCTL_VFE: u32 = 1 << 18; // VLAN Filter Enable +const RCTL_CFIEN: u32 = 1 << 19; // Canonical Form Indicator Enable +const RCTL_CFI: u32 = 1 << 20; // Canonical Form Indicator Bit Value +const RCTL_DPF: u32 = 1 << 22; // Discard Pause Frames +const RCTL_PMCF: u32 = 1 << 23; // Pass MAC Control Frames +const RCTL_SECRC: u32 = 1 << 26; // Strip Ethernet CRC + +// Buffer Sizes +const RCTL_BSIZE_256: u32 = 3 << 16; +const RCTL_BSIZE_512: u32 = 2 << 16; +const RCTL_BSIZE_1024: u32 = 1 << 16; +const RCTL_BSIZE_2048: u32 = 0 << 16; +const RCTL_BSIZE_4096: u32 = (3 << 16) | (1 << 25); +const RCTL_BSIZE_8192: u32 = (2 << 16) | (1 << 25); +const RCTL_BSIZE_16384: u32 = (1 << 16) | (1 << 25); + +// Transmit Command +const CMD_EOP: u32 = 1 << 0; // End of Packet +const CMD_IFCS: u32 = 1 << 1; // Insert FCS +const CMD_IC: u32 = 1 << 2; // Insert Checksum +const CMD_RS: u32 = 1 << 3; // Report Status +const CMD_RPS: u32 = 1 << 4; // Report Packet Sent +const CMD_VLE: u32 = 1 << 6; // VLAN Packet Enable +const CMD_IDE: u32 = 1 << 7; // Interrupt Delay Enable + +// TCTL Register +const TCTL_EN: u32 = 1 << 1; // Transmit Enable +const TCTL_PSP: u32 = 1 << 3; // Pad Short Packets +const TCTL_CT_SHIFT: u32 = 4; // Collision Threshold +const TCTL_COLD_SHIFT: u32 = 12; // Collision Distance +const TCTL_SWXOFF: u32 = 1 << 22; // Software XOFF Transmission +const TCTL_RTLC: u32 = 1 << 24; // Re-transmit on Late Collision + +const TSTA_DD: u32 = 1 << 0; // Descriptor Done +const TSTA_EC: u32 = 1 << 1; // Excess Collisions +const TSTA_LC: u32 = 1 << 2; // Late Collision +const LSTA_TU: u32 = 1 << 3; // Transmit Underrun + +#[repr(C)] +struct e1000_rx_desc { + volatile uint64_t addr; + volatile uint16_t length; + volatile uint16_t checksum; + volatile uint8_t status; + volatile uint8_t errors; + volatile uint16_t special; +} __attribute__((packed)); + +#[repr(C)] +struct e1000_tx_desc { + volatile uint64_t addr; + volatile uint16_t length; + volatile uint8_t cso; + volatile uint8_t cmd; + volatile uint8_t status; + volatile uint8_t css; + volatile uint16_t special; +} __attribute__((packed)); + + +struct E1000 { + // Type of BAR0 + bar_type: u8, + + // IO Base Address + io_base: u16, + + // MMIO Base Address + mem_base: u64, + + // A flag indicating if eeprom exists + eerprom_exists: bool, + + // A buffer for storing the MAC address + mac: [u8; 6], + + struct e1000_rx_desc *rx_descs[E1000_NUM_RX_DESC]; // Receive Descriptor Buffers + struct e1000_tx_desc *tx_descs[E1000_NUM_TX_DESC]; // Transmit Descriptor Buffers + + // Current Receive Descriptor Buffer + rx_cur: u16, + + // Current Transmit Descriptor Buffer + tx_cur: u16, +} + +impl E1000 { + // Constructor. takes as a parameter a pointer to an object that encapsulate all he PCI configuration data of the device + pub fn new(pci_config: Device) -> E1000 { + unimplemented!("Unimplemented"); + } + + // Perform initialization tasks and starts the driver + pub fn init() -> () { + unimplemented!("Unimplemented"); + } + + // This method should be called by the interrupt handler + pub fn fire(p_interruptContext: InterruptContext) -> () { + unimplemented!("Unimplemented"); + } + + // Returns the MAC address + pub fn get_mac_address() -> [u8; 6]{ + unimplemented!("Unimplemented"); + } + + pub fn send_packet(p_data: Box<()>, p_len: u16) -> u32 { + unimplemented!("Unimplemented"); + } + + // Send Commands and read results From NICs either using MMIO or IO Ports + fn write_command(p_address: u16, p_value: u32) -> () { + unimplemented!("Unimplemented"); + } + fn send_command(p_address: u16) -> () { + unimplemented!("Unimplemented"); + } + + // Detect if EE Prom exists + fn detect_ee_prom() -> bool { + unimplemented!("Unimplemented"); + } + + // Read 4 bytes from a specific EEProm Address + fn eeprom_read(addr: u8) -> u32 { + unimplemented!("Unimplemented"); + } + + // Read MAC Address + fn read_mac_address() -> () { + unimplemented!("Unimplemented"); + } + + // Start up the network + fn start_link() -> () { + unimplemented!("Unimplemented"); + } + + // Initialize receive descriptors an buffers + fn rxinit() -> () { + unimplemented!("Unimplemented"); + } + + // Initialize transmit descriptors an buffers + fn txinit() -> () { + unimplemented!("Unimplemented"); + } + + // Enable interrupts + fn enable_interrupts() -> () { + unimplemented!("Unimplemented"); + } + + // Handle a packet reception + fn handle_receive() -> () { + unimplemented!("Unimplemented"); + } +} \ No newline at end of file diff --git a/kernel/src/network/ethernet.rs b/kernel/src/network/ethernet.rs new file mode 100644 index 0000000..f52f9ff --- /dev/null +++ b/kernel/src/network/ethernet.rs @@ -0,0 +1,91 @@ +use super::{ + bytefield::{Bytefield16, Bytefield48}, + layer::{EmptyLayer, Layer, LayerType}, +}; +use alloc::vec; +use alloc::vec::Vec; + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[repr(u16)] +pub enum EthType { + Arp = 0x0806, + IPv4 = 0x0800, + RoCE = 0x8915, + Unknown = 0, +} + +impl EthType { + pub fn from(packet_type: u16) -> Self { + match packet_type { + 0x0806 => Self::Arp, + 0x0800 => Self::IPv4, + _ => Self::Unknown, + } + } + + pub fn as_bytefield(&self) -> Bytefield16 { + Bytefield16::new(*self as u16) + } +} + +// Total size is 14 bytes +#[derive(Debug)] +pub struct EthernetPacket { + pub dest_mac: Bytefield48, // u48 + pub src_mac: Bytefield48, // u48, + pub packet_type: EthType, // u16 +} + +impl EthernetPacket { + pub fn new() -> Self { + EthernetPacket { + dest_mac: Bytefield48::new(0), + src_mac: Bytefield48::new(0), + packet_type: EthType::Unknown, + } + } + + pub fn gen(destination_mac: u64, source_mac: u64, packet_type: EthType) -> Self { + EthernetPacket { + dest_mac: Bytefield48::new(destination_mac), + src_mac: Bytefield48::new(source_mac), + packet_type, + } + } +} + +impl Layer for EthernetPacket { + type Input = EmptyLayer; + fn parse(_empty_layer: EmptyLayer, bytevec: &[u8]) -> (Self, usize, LayerType) + where + Self: Sized, + { + let mut packet = EthernetPacket::new(); // create an empty packet + // Read 14 bytes + let mut i = 0; + packet.dest_mac = Bytefield48::read_inc(&bytevec[i..], &mut i); + packet.src_mac = Bytefield48::read_inc(&bytevec[i..], &mut i); + packet.packet_type = EthType::from(Bytefield16::read_inc(&bytevec[i..], &mut i).val()); + assert!(i == 14); // 14 bytes + let layer_type = match &packet.packet_type { + EthType::Arp => LayerType::ARP, + EthType::IPv4 => LayerType::IP, + EthType::RoCE => LayerType::UNDEF, + EthType::Unknown => LayerType::UNDEF, + }; + return (packet, i, layer_type); + } + + fn serialize(&self) -> Vec { + let mut res = vec![]; + res.extend(self.dest_mac.data); + res.extend(self.src_mac.data); + res.extend(self.packet_type.as_bytefield().data); + assert!(res.len() == 14); + res + } + + fn packet_size() -> u16 { + 14 + } +} diff --git a/kernel/src/network/icmp.rs b/kernel/src/network/icmp.rs new file mode 100644 index 0000000..e69de29 diff --git a/kernel/src/network/init.rs b/kernel/src/network/init.rs new file mode 100644 index 0000000..eb022e3 --- /dev/null +++ b/kernel/src/network/init.rs @@ -0,0 +1,63 @@ +use crate::network::dhcp::DHCPPacket; +use crate::network::ethernet::{EthernetPacket, EthType}; +use crate::network::ip::{IPPacket, Protocol}; +use crate::network::rtl8139::{NET_INFO, disable_network_interrupts, enable_network_interrupts}; +use crate::network::udp::UDPPacket; +use crate::{network::constants::DHCP_SERVER_PORT, println}; +use crate::network::layer::{LayerType, Layer}; +use crate::network::socket::RawSocket; +use super::constants::{BROADCAST_ADDR, DHCP_CLIENT_PORT, BROADCAST_MAC}; + +pub fn init() -> (){ + // todo bundle the init phases +} + +pub fn init_dhcp(wait_timeout: u8) -> bool { + // open a socket (shouldn't fail since we are in the init process) + let mut socket = RawSocket::new(DHCP_CLIENT_PORT).unwrap(); + + disable_network_interrupts(); + let rtl_dev_guard = NET_INFO.lock(); + let rtl_dev_info = rtl_dev_guard.get_ref().unwrap(); + + // send dhcp initial request + let eth = EthernetPacket::gen(BROADCAST_MAC, rtl_dev_info.mac_address.unwrap(), EthType::IPv4); + let ip_size = DHCPPacket::packet_size() + UDPPacket::packet_size(); + let ip = IPPacket::gen(eth, ip_size, Protocol::UDP, 0x0, BROADCAST_ADDR); + let udp = UDPPacket::gen(ip, DHCP_CLIENT_PORT, DHCP_SERVER_PORT, DHCPPacket::packet_size()); + let dhcp = DHCPPacket::gen(udp, None, rtl_dev_info.mac_address.unwrap()); + let packet_data = dhcp.serialize(); + + rtl_dev_info.send_packet(&packet_data); // send first packet + drop(rtl_dev_guard); + enable_network_interrupts(); + println!("Sent packet"); + + // get response + let mut timeout = 0; + let pkt_data; + loop { + if let Some(dhcp_res) = socket.get_packet_with_timeout(1) { + println!("Got good packet"); + pkt_data = dhcp_res; + break; + } + timeout += 1; + if timeout == wait_timeout { + socket.close(); + return false; + } + } + + socket.close(); + disable_network_interrupts(); + let mut rtl_dev_guard = NET_INFO.lock(); + let rtl_dev_info = rtl_dev_guard.get_mut().unwrap(); + if pkt_data.get_type() == LayerType::DHCP { + let dhcp_res = pkt_data.unwrap_dhcp(); + rtl_dev_info.my_ip_address = Some(dhcp_res.my_ip.val()); + rtl_dev_info.dhcp_server_ip = Some(dhcp_res.server_ip.val()); + } + enable_network_interrupts(); + return true; +} diff --git a/kernel/src/network/ip.rs b/kernel/src/network/ip.rs index cfdc1ee..25161d2 100644 --- a/kernel/src/network/ip.rs +++ b/kernel/src/network/ip.rs @@ -1,9 +1,12 @@ +use alloc::vec::Vec; +use alloc::vec; + use super::{ - bitfield::{Bitfield16, Bitfield32}, - packet::Packet, + bytefield::{Bytefield8, Bytefield16, Bytefield32}, + layer::{Layer, HasChecksum, LayerType}, ethernet::EthernetPacket, }; -#[derive(Clone, Copy)] +#[derive(Debug, Clone, Copy)] #[repr(u8)] pub enum Protocol { ICMP = 1, @@ -19,82 +22,171 @@ impl Protocol { 1 => Self::ICMP, 6 => Self::TCP, 17 => Self::UDP, - 27 => Self::RDP, _ => Self::Unsupported, } } + + pub fn as_byte(&self) -> u8 { + *self as u8 + } } struct WrappedU16 { - data: u16 + data: u16, } impl WrappedU16 { - pub fn get(&self) -> u16 { self.data } + pub fn get(&self) -> u16 { + self.data + } pub fn set(&mut self, data: u16) { self.data = data; } } static mut ID_GEN: spin::Mutex = spin::Mutex::new(WrappedU16 { data: 0 }); - -struct IPPacket { - version_hlen: u8, // 1 byte - type_of_service: u8, // 1 byte - total_length: Bitfield16, // 2 bytes - identification: Bitfield16, // 2 bytes - flags_fragment_offset: Bitfield16, // 2 bytes - ttl: u8, // 1 byte - protocol: Protocol, // 1 byte - checksum: Bitfield16, // 2 bytes - source_ip: Bitfield32, // 4 bytes - destination_ip: Bitfield32, // 4 bytes - // 20 bytes in total +#[derive(Debug)] +pub struct IPPacket { + pub ethernet_packet: EthernetPacket, + version_hlen: u8, // 1 byte + type_of_service: u8, // 1 byte + pub total_length: Bytefield16, // 2 bytes (public for checksumming) + identification: Bytefield16, // 2 bytes + flags_fragment_offset: Bytefield16, // 2 bytes + ttl: u8, // 1 byte + pub protocol: Protocol, // 1 byte (public for checksumming) + pub checksum: Bytefield16, // 2 bytes + pub source_ip: Bytefield32, // 4 bytes (public for checksumming) + pub destination_ip: Bytefield32, // 4 bytes (public for checksumming) + // 20 bytes in total } impl IPPacket { pub fn new() -> Self { IPPacket { + ethernet_packet: EthernetPacket::new(), version_hlen: 0, type_of_service: 0, - total_length: Bitfield16::new(0), - identification: Bitfield16::new(0), - flags_fragment_offset: Bitfield16::new(0), + total_length: Bytefield16::new(0), + identification: Bytefield16::new(0), + flags_fragment_offset: Bytefield16::new(0), ttl: 0, protocol: Protocol::Unsupported, - checksum: Bitfield16::new(0), - source_ip: Bitfield32::new(0), - destination_ip: Bitfield32::new(0), + checksum: Bytefield16::new(0), + source_ip: Bytefield32::new(0), + destination_ip: Bytefield32::new(0), } } - pub fn gen(data_length: u16, protocol: Protocol, src_ip: u32, dst_ip: u32) -> Self { - let identification = unsafe { + pub fn gen(ethernet_packet: EthernetPacket, data_length: u16, protocol: Protocol, src_ip: u32, dst_ip: u32) -> Self { + let identification = unsafe { let mut id_gen = ID_GEN.lock(); - id_gen.set((id_gen.get() + 1) % 0xFFFF); - Bitfield16::new(id_gen.get()) + let id_gen_old = id_gen.get(); + id_gen.set((id_gen_old + 1) % 0xFFFF); + Bytefield16::new(id_gen.get()) }; IPPacket { + ethernet_packet, version_hlen: 0x45, type_of_service: 0x0, - total_length: Bitfield16::new(data_length + 20), // adding data length and size of IP packet + total_length: Bytefield16::new(data_length + 20), // adding data length and size of IP packet identification, - flags_fragment_offset: Bitfield16::new(0), + flags_fragment_offset: Bytefield16::new(0), ttl: 120, protocol, - checksum: todo!(), // todo! compute checksum - source_ip: Bitfield32::new(src_ip), - destination_ip: Bitfield32::new(dst_ip), + checksum: Bytefield16::new(0), + source_ip: Bytefield32::new(src_ip), + destination_ip: Bytefield32::new(dst_ip), } } } -impl Packet for IPPacket { - fn parse(bytevec: &[u8]) -> (Self, usize) where Self: Sized, { - +impl Layer for IPPacket { + type Input = EthernetPacket; + fn parse(ethernet_layer: EthernetPacket, bytevec: &[u8]) -> (Self, usize, LayerType) + where + Self: Sized, + { + let mut packet = IPPacket::new(); // create an empty packet + // Read 20 bytes + let mut i = 0; + packet.ethernet_packet = ethernet_layer; + packet.version_hlen = Bytefield8::read_inc(&bytevec[i..], &mut i).data; + packet.type_of_service = Bytefield8::read_inc(&bytevec[i..], &mut i).data; + packet.total_length = Bytefield16::read_inc(&bytevec[i..], &mut i); + packet.identification = Bytefield16::read_inc(&bytevec[i..], &mut i); + packet.flags_fragment_offset = Bytefield16::read_inc(&bytevec[i..], &mut i); + packet.ttl = Bytefield8::read_inc(&bytevec[i..], &mut i).data; + let protocol = Bytefield8::read_inc(&bytevec[i..], &mut i); + packet.protocol = Protocol::from(protocol.data); + packet.checksum = Bytefield16::read_inc(&bytevec[i..], &mut i); + packet.source_ip = Bytefield32::read_inc(&bytevec[i..], &mut i); + packet.destination_ip = Bytefield32::read_inc(&bytevec[i..], &mut i); + assert!(i == 20); // 20 bytes + let layer_type = match packet.protocol { + Protocol::ICMP => LayerType::ICMP, + Protocol::TCP => LayerType::TCP, + Protocol::UDP => LayerType::UDP, + Protocol::RDP => LayerType::UNDEF, + Protocol::Unsupported => LayerType::UNDEF, + }; + return (packet, i, layer_type); + } + + fn serialize(&self) -> Vec { + let mut res = vec![]; + res.extend(self.ethernet_packet.serialize()); + res.push(self.version_hlen); + res.push(self.type_of_service); + res.extend(self.total_length.data); + res.extend(self.identification.data); + res.extend(self.flags_fragment_offset.data); + res.push(self.ttl); + res.push(self.protocol.as_byte()); + res.extend(self.checksum.data); + res.extend(self.source_ip.data); + res.extend(self.destination_ip.data); + assert!(res.len() == (20 + self.ethernet_packet.serialize().len())); + res } - fn serialize(&self) -> alloc::vec::Vec { - + fn packet_size() -> u16 { + 20 } } + +impl HasChecksum for IPPacket { + fn calculate_checksum(&mut self, data: &[u8]) -> () { + // Starting vars + let mut sum: u32 = 0; + + // Sum the body + self.checksum = Bytefield16::new(0); + let mut ptr = 0; + let mut ip_len = data.len(); + while ip_len > 1 { + sum += (data[ptr] as u32) | ((data[ptr + 1] as u32) << 8); + ip_len -= 2; + ptr += 2; + } + + if data.len() % 2 == 1 { + // Add the padding if the packet length is odd + sum += (data[ptr] as u32) << 8; + } + + // Add the carries + while sum > 0xFFFF { + sum = (sum & 0xFFFF) + (sum >> 16); + } + + // One's complement + let mut res = !sum as u16; + // Swap the bytes because we did our sum in big endian + // (and the bytefield will try to convert to big endian) + res = ((res >> 8) & 0xFF) | ((res & 0xFF) << 8); + + // Return the one's complement of sum + self.checksum = Bytefield16::new(res); + } +} \ No newline at end of file diff --git a/kernel/src/network/layer.rs b/kernel/src/network/layer.rs new file mode 100644 index 0000000..ce42f4f --- /dev/null +++ b/kernel/src/network/layer.rs @@ -0,0 +1,172 @@ +use alloc::vec; +use alloc::vec::Vec; + +use super::arp::ArpPacket; +use super::dhcp::DHCPPacket; +use super::ethernet::EthernetPacket; +use super::ip::IPPacket; +use super::udp::UDPPacket; + +pub trait Layer { + type Input; + + /// Parse a bytevec (vector slice) and return the packet and the amount of data parsed + fn parse(upper: Self::Input, bytevec: &[u8]) -> (Self, usize, LayerType) where Self: Sized; + + /// Convert a packet into a vector of bytes + fn serialize(&self) -> Vec; + + /// The size of the packet + fn packet_size() -> u16; +} +#[derive(Debug)] +pub struct EmptyLayer {} +impl EmptyLayer { + pub fn new() -> Self { + EmptyLayer { } + } +} + +impl Layer for EmptyLayer { + type Input = EmptyLayer; + fn parse(_upper: EmptyLayer, _bytevec: &[u8]) -> (Self, usize, LayerType) where Self: Sized { + (Self {}, 0, LayerType::UNDEF) + } + + fn serialize(&self) -> Vec { + vec![] + } + + fn packet_size() -> u16 { + 0 + } +} + +pub trait HasChecksum { + /// Calculate the checksum and self mutate + fn calculate_checksum(&mut self, data: &[u8]) -> (); +} + +#[derive(Debug, PartialEq, Eq)] +pub enum LayerType { + ETH, + IP, + ARP, + UDP, + ICMP, + DHCP, + TCP, + UNDEF // the default layer type +} + +/// Wrapper type to allow me to return a generic +#[derive(Debug)] +pub enum PacketData { + ETH(EthernetPacket), + IP(IPPacket), + ARP(ArpPacket), + UDP(UDPPacket), + ICMP(EmptyLayer), + DHCP(DHCPPacket), + TCP(EmptyLayer), + UNDEF(EmptyLayer) +} + +impl PacketData { + pub fn unwrap_eth(self) -> EthernetPacket { + match self { + PacketData::ETH(val) => val, + _ => unreachable!("Mismatched type. Couldn't unwrap"), + } + } + pub fn unwrap_ip(self) -> IPPacket { + match self { + PacketData::IP(val) => val, + _ => unreachable!("Mismatched type. Couldn't unwrap"), + } + } + pub fn unwrap_arp(self) -> ArpPacket { + match self { + PacketData::ARP(val) => val, + _ => unreachable!("Mismatched type. Couldn't unwrap"), + } + } + pub fn unwrap_udp(self) -> UDPPacket { + match self { + PacketData::UDP(val) => val, + _ => unreachable!("Mismatched type. Couldn't unwrap"), + } + } + pub fn unwrap_dhcp(self) -> DHCPPacket { + match self { + PacketData::DHCP(val) => val, + _ => unreachable!("Mismatched type. Couldn't unwrap"), + } + } + pub fn unwrap_undef(self) -> EmptyLayer { + match self { + PacketData::UNDEF(val) => val, + _ => unreachable!("Mismatched type. Couldn't unwrap"), + } + } + pub fn get_type(&self) -> LayerType { + match self { + PacketData::ETH(_) => LayerType::ETH, + PacketData::IP(_) => LayerType::IP, + PacketData::ARP(_) => LayerType::ARP, + PacketData::UDP(_) => LayerType::UDP, + PacketData::ICMP(_) => LayerType::ICMP, + PacketData::DHCP(_) => LayerType::DHCP, + PacketData::TCP(_) => LayerType::TCP, + PacketData::UNDEF(_) => LayerType::UNDEF, + } + } +} + +pub fn full_parse(packet: &Vec) -> (usize, PacketData) { + let mut i = 0; + let mut last_layer = PacketData::UNDEF(EmptyLayer::new()); + let mut next_type = LayerType::ETH; + loop { + match next_type { + LayerType::ETH => { + let last_layer_data = last_layer.unwrap_undef(); + let (eth_layer, size, network_layer_type) = EthernetPacket::parse(last_layer_data, &packet[i..]); + last_layer = PacketData::ETH(eth_layer); + i += size; + next_type = network_layer_type; + }, + LayerType::IP => { + let last_layer_data = last_layer.unwrap_eth(); + let (ip_layer, size, transport_layer_type) = IPPacket::parse(last_layer_data, &packet[i..]); + last_layer = PacketData::IP(ip_layer); + i += size; + next_type = transport_layer_type; + }, + LayerType::ARP => { + let last_layer_data = last_layer.unwrap_eth(); + let (arp_layer, size, transport_layer_type) = ArpPacket::parse(last_layer_data, &packet[i..]); + last_layer = PacketData::ARP(arp_layer); + i += size; + next_type = transport_layer_type; + }, + LayerType::UDP => { + let last_layer_data = last_layer.unwrap_ip(); + let (udp_layer, size, application_layer_type) = UDPPacket::parse(last_layer_data, &packet[i..]); + last_layer = PacketData::UDP(udp_layer); + i += size; + next_type = application_layer_type; + }, + LayerType::ICMP => { return (0, PacketData::UNDEF(EmptyLayer::new())); }, + LayerType::DHCP =>{ + let last_layer_data = last_layer.unwrap_udp(); + let (dhcp_layer, size, empty_type) = DHCPPacket::parse(last_layer_data, &packet[i..]); + last_layer = PacketData::DHCP(dhcp_layer); + i += size; + next_type = empty_type; + }, + LayerType::TCP => { return (0, PacketData::UNDEF(EmptyLayer::new())); }, + LayerType::UNDEF => { return (i, last_layer); }, + } + }; +} \ No newline at end of file diff --git a/kernel/src/network/mod.rs b/kernel/src/network/mod.rs index 2f0fe75..4e6bd7b 100644 --- a/kernel/src/network/mod.rs +++ b/kernel/src/network/mod.rs @@ -1,12 +1,29 @@ +pub mod bytefield; +pub mod arp; +pub mod command_register; pub mod devices; +pub mod ethernet; +pub mod ip; +pub mod layer; +pub mod rtl8139; +pub mod udp; +pub mod dhcp; +pub mod socket; +pub mod init; +// todo: remove pub until things break... +mod raw_array; +mod arp_table; +mod netsync; +pub mod constants; +// pub mod e1000; /* pub struct NetworkIO { } impl NetworkIO { - // Create a new NetworkIO instance + // Create a new NetworkIO instance (to use all possible network drivers? IDK) fn new() -> Self { return NetworkIO { diff --git a/kernel/src/network/netsync.rs b/kernel/src/network/netsync.rs new file mode 100644 index 0000000..ecc7d41 --- /dev/null +++ b/kernel/src/network/netsync.rs @@ -0,0 +1,62 @@ +use spin::MutexGuard; + +use crate::println; + +use super::rtl8139::{RTL8139, disable_network_interrupts, enable_network_interrupts}; + +pub struct NetworkInterruptsGuard<'a> { + data: MutexGuard<'a, Option> +} + +impl NetworkInterruptsGuard<'_> { + pub fn get_mut(&mut self) -> Option<&mut RTL8139> { + return self.data.as_mut(); + } + + pub fn get_ref(&self) -> Option<&RTL8139> { + return self.data.as_ref(); + } +} + +impl Drop for NetworkInterruptsGuard<'_> { + fn drop(&mut self) { + // re-enable network interrupts when we drop + // drop(self.data); + // enable_network_interrupts(); + // println!("Enabling network interrupts") + } +} + +pub struct SafeRTL8139 { + data: spin::Mutex> +} + +impl SafeRTL8139 { + pub fn new(data: spin::Mutex>) -> Self { + Self { data } + } + pub fn lock(&self) -> NetworkInterruptsGuard { + // disable_network_interrupts(); + return NetworkInterruptsGuard { data: self.data.lock() } + } + pub fn lock_no_disable(&self) -> MutexGuard> { + return self.data.lock(); + } +} + + +pub struct InterruptCounter { + pub data: u32, +} +impl InterruptCounter { + pub fn get(&self) -> u32 { + return self.data; + } + pub fn inc(&mut self) -> () { + self.data += 1; + } + + pub fn dec(&mut self) -> () { + self.data -= 1; + } +} diff --git a/kernel/src/network/raw_array.rs b/kernel/src/network/raw_array.rs new file mode 100644 index 0000000..f00cec3 --- /dev/null +++ b/kernel/src/network/raw_array.rs @@ -0,0 +1,98 @@ +use core::ops::Index; +use alloc::vec; +use alloc::vec::Vec; + +// Leaving this here unless we change an implementation that necessiates a differnt type of array +/*pub struct RawArray { + start: *const u8 +} + +impl RawArray { + /// An infinite array beginning at "start" + pub fn new(start: *const u8) -> Self { + RawArray { + start + } + } + + // Ignore values + pub fn shift_amount(&mut self, amount: usize) -> () { + self.start = unsafe { self.start.add(amount) }; + } + + // Move the array forward, "consuming" those values + pub fn trim(&mut self, amount: usize) -> Vec { + let mut res = vec![]; + for _ in 0..amount { + unsafe { + res.push(*self.start); + self.start = self.start.add(1); + } + } + res + } + +} + +impl Index for RawArray { + type Output = u8; + /// Index into the infinite array using raw pointers + fn index<'a>(&'a self, i: usize) -> &u8 { + unsafe { &(*self.start.add(i)) } + } +}*/ + + +pub struct WrappingRawArray { + start: *const u8, + pos: usize, + size: usize, +} + +impl WrappingRawArray { + /// An infinite array beginning at "start" and wrapping after size bytes + pub fn new(start: *const u8, size: usize) -> Self { + WrappingRawArray { + start, + pos: 0, + size, + } + } + + // Ignore values + pub fn shift_amount(&mut self, amount: usize) -> () { + self.pos = (self.pos + amount) % self.size; + } + + // Move the array forward, "consuming" those values + pub fn trim(&mut self, amount: usize) -> Vec { + let mut res = vec![]; + let mut tmp_start = unsafe { self.start.add(self.pos) }; + for _ in 0..amount { + // append the byte and move tmp_start forward + unsafe { + res.push(*tmp_start); + tmp_start = tmp_start.add(1); + } + // also increment the position + self.pos += 1; + // if the position is equal to size (we are outside the buffer) + if self.pos == self.size { + // so we reset to the beginning + self.pos = 0; + tmp_start = self.start; + } + } + res + } + +} + +impl Index for WrappingRawArray { + type Output = u8; + /// Index into the infinite array using raw pointers + fn index<'a>(&'a self, i: usize) -> &u8 { + unsafe { &(*self.start.add(i % self.size)) } + } +} + diff --git a/kernel/src/network/rtl8139.rs b/kernel/src/network/rtl8139.rs new file mode 100644 index 0000000..3104c5a --- /dev/null +++ b/kernel/src/network/rtl8139.rs @@ -0,0 +1,464 @@ +use alloc::{vec, collections::VecDeque}; +use alloc::vec::Vec; +use hashbrown::{HashMap, HashSet}; +use lazy_static::lazy_static; + +use x86_64::{ + instructions::port::Port, + structures::{ + idt::InterruptStackFrame, + paging::{FrameAllocator, PhysFrame}, + }, + PhysAddr, VirtAddr, +}; + +use crate::interrupts::IDT; + +use crate::network::constants::{RX_BUFFER_SIZE, CR, CR_RE, CR_TE, CR_BUFE, RX_READ_PTR_MASK, CAPR}; +use crate::network::raw_array::WrappingRawArray; + +use crate::{ + interrupts::{InterruptHandler, PICS}, + memory::BootInfoFrameAllocator, + network::{devices, netsync::SafeRTL8139, ethernet::{EthernetPacket, EthType}, arp::ArpPacket, layer::Layer}, + println, +}; +use super::constants::{INTERRUPT_MASK, ROK, TOK, RTL_VEND, RTL_DEV, TRANSMIT_REG, TRANSMIT_CMD, BROADCAST_ADDR}; +use super::{ + arp_table::ArpEntry, + devices::{Device, PCIClassCodes}, + layer::{full_parse, EmptyLayer, PacketData}, + netsync::InterruptCounter, +}; + +// ISR_ROK|ISR_TOK|ISR_RXOVW|ISR_TER|ISR_RER +// ! URGENT: start checking other statuses for errors +// todo: why can I get "send" interrupts... what is the purpose + +// FROM OS DEV +// N.B. If you find your driver suddenly freezes and stops receiving interrupts and you're using kvm/qemu. Try the option -no-kvm-irqchip +static mut TRANSMIT_IDX: u32 = 0; +static mut RECV_POS: u16 = 0; // todo this should be wrapped in a lock? +static mut IO_BASE: usize = 0; +// these static muts are set from None to Some and never really changed + +lazy_static! { + // ! if a process tries to get this lock without disabling interrupts, we would deadlock + // ! The safertl unwraps the locking mechanism with a RAII enable and disable (Well it would if I could get it to work) + pub static ref NET_INFO: SafeRTL8139 = { + let devices = devices::scan_devices(); + SafeRTL8139::new(spin::Mutex::new(RTL8139::new(devices))) + }; +} + +static mut INTERRUPTS_ARE_ENABLED: spin::Mutex = spin::Mutex::new(InterruptCounter { data: 0 }); +// Disable network interrupts (is thread safe) +pub fn disable_network_interrupts() -> () { + let mut data = unsafe { INTERRUPTS_ARE_ENABLED.lock() }; + if data.get() == 0 { + let mut port_imr = Port::::new((unsafe { IO_BASE } + 0x3C) as u16); + unsafe { port_imr.write(0x0) }; + } + data.inc(); +} + +// Enable network interrupts (is thread safe) +pub fn enable_network_interrupts() -> () { + let mut data = unsafe { INTERRUPTS_ARE_ENABLED.lock() }; + data.dec(); + if data.get() == 0 { + let mut port_imr = Port::::new((unsafe { IO_BASE } + 0x3C) as u16); + unsafe { port_imr.write(INTERRUPT_MASK) }; + } +} + +pub extern "x86-interrupt" fn network_handle(_stack_frame: InterruptStackFrame) -> () { + println!("[INTERRUPT] - "); + // Try to get the device info + let mut net_dev = NET_INFO.lock_no_disable(); + if net_dev.is_none() { + panic!("RTL_INFO is undefined!"); + } + // Get the device fields + let rtl_dev_info = net_dev.as_mut().unwrap(); + let io_base = rtl_dev_info.config.io_base; + let irq = rtl_dev_info.config.irq; + if io_base.is_none() || irq.is_none() { + println!("[ERR] Handling packet - missing data"); + unsafe { + PICS.lock().notify_end_of_interrupt(irq.unwrap() + 32); + } + return; + } + + // stop interrupts to the device + let mut port_imr = Port::::new((io_base.unwrap() + 0x3C) as u16); + unsafe { port_imr.write(0x0) }; + + // Read the ISR register + let mut port_isr = Port::::new((io_base.unwrap() + 0x3E) as u16); + let status = unsafe { port_isr.read() }; + // Reset the ISR register + unsafe { port_isr.write(0x05) }; + // println!("!! {} !!", status); + if status & TOK != 0x0 { + // Sent + // println!("Sending packet"); + } + if status & ROK != 0x0 { + // println!("Receiving packet"); + // Received packet + let pkt = recv_packet(&rtl_dev_info); + match pkt { + PacketData::ARP(arp) => { + // todo: also check for broadcast + if arp.recp_ip.val() == rtl_dev_info.my_ip_address.unwrap_or(0) { + // println!("[INT-HANDLER] Send a response back"); + let eth_layer = EthernetPacket::gen(arp.sender_mac.val(), rtl_dev_info.mac_address.unwrap(), EthType::Arp); + let arp_layer = ArpPacket::gen(eth_layer, rtl_dev_info.my_ip_address.unwrap_or(BROADCAST_ADDR), arp.sender_ip.val(), false); + let arp_pkt = arp_layer.serialize(); + rtl_dev_info.send_packet(&arp_pkt); + } else { + // println!("[INT-HANDLER] Receiving arp reply"); + // todo: expire from arp table? + rtl_dev_info.arp_table.push(ArpEntry { + mac: arp.sender_mac.val(), + ip: arp.sender_ip.val(), + expires: 0, + }); + } + } + PacketData::DHCP(dhcp) => { + let dst_port = dhcp.udp_packet.dest_port.val(); + // println!("[INT-HANDLER] Found DHCP packet"); + if rtl_dev_info.open_ports.contains(&dst_port) { + // println!("[INT-HANDLER] Port {} is open", dst_port); + // if we are listening on the port, try to insert it into the map + if !rtl_dev_info.ports.contains_key(&dst_port) { + rtl_dev_info.ports.insert(dst_port, VecDeque::new()); + } + rtl_dev_info + .ports + .get_mut(&dst_port) + .unwrap() + .push_back(PacketData::DHCP(dhcp)); + } + } + PacketData::UDP(udp) => { + let dst_port = udp.dest_port.val(); + if rtl_dev_info.open_ports.contains(&dst_port) { + // if we are listening on the port, try to insert it into the map + if !rtl_dev_info.ports.contains_key(&dst_port) { + rtl_dev_info.ports.insert(dst_port, VecDeque::new()); + } + rtl_dev_info + .ports + .get_mut(&dst_port) + .unwrap() + .push_back(PacketData::UDP(udp)); + } + }, + _ => {} // ignore others + } + } + + // Allow interrupts to the device + unsafe { port_imr.write(INTERRUPT_MASK) }; + + // Notify end of interrupt + unsafe { + PICS.lock().notify_end_of_interrupt(irq.unwrap() + 32); + } +} + +// todo: refactor to be a loop, this function needs to return a list +fn recv_packet(rtl_dev_info: &RTL8139) -> PacketData { + if rtl_dev_info.recv_buffer.is_none() || rtl_dev_info.physical_mem_offset.is_none() { + panic!("RTL8139 is not initialized properly"); + } + // Make sure buffer isn't empty + let cmd_reg = (rtl_dev_info.config.io_base.unwrap() + CR) as u16; + let mut cmd_port = Port::::new(cmd_reg); + while unsafe { cmd_port.read() } & CR_BUFE == 0x0 { + // Receive a packet by reading the buffer + // ? Reading the buffer is naturally unsafe? Is there a better way? + let virtual_buffer_recv: VirtAddr = VirtAddr::new( + rtl_dev_info.recv_buffer.unwrap().as_u64() + rtl_dev_info.physical_mem_offset.unwrap(), + ); + // todo: check for packet validity https://www.cs.usfca.edu/~cruse/cs326f04/RTL8139_ProgrammersGuide.pdf + let vb_recv_start: *const u8 = virtual_buffer_recv.as_ptr(); + let mut rx_buffer = WrappingRawArray::new(vb_recv_start, RX_BUFFER_SIZE as usize); + rx_buffer.shift_amount(unsafe { RECV_POS } as usize); + let header_buf = rx_buffer.trim(2); + let header = (header_buf[0] as u16) | (header_buf[1] as u16) << 8; + // println!("Header {}", header); + // Checking receive OK and no errors + if header & 0x01 != 0 && header & 0x02 == 0 && header & 0x04 == 0 && header & 0x20 == 0 { + let length_buf = rx_buffer.trim(2); // get the next two bytes + let length = (length_buf[0] as u16) | (length_buf[1] as u16) << 8; + // println!("Length {}", length); + let packet = rx_buffer.trim((length - 4) as usize); + // ? throw out the crc... we don't need to check it... + rx_buffer.shift_amount(4); + let amount_parsed_and_pkt = full_parse(&packet); + + // the amount we parse will be equal to length unless we are under the minimum + assert!(amount_parsed_and_pkt.0 == (length - 4) as usize || length >= 64); + // after receiving the packet, update CAPR and RECV_POS + // increment recv_pos + unsafe { RECV_POS = (RECV_POS + 2) % RX_BUFFER_SIZE; } + unsafe { RECV_POS = (RECV_POS + length) % RX_BUFFER_SIZE; } + // we and with RX_READ_PTR_MASK to ensure double word alignment + // 4 for header length, 3 is also apparently for dword alignment... + unsafe { RECV_POS = ((RECV_POS + 4) & RX_READ_PTR_MASK) % RX_BUFFER_SIZE; } + let mut capr = Port::::new((rtl_dev_info.config.io_base.unwrap() + CAPR) as u16); + // println!("[RECV_POS] {}", unsafe { RECV_POS }); + unsafe { capr.write(RECV_POS - 0x10) }; + return amount_parsed_and_pkt.1; + } else { + unsafe { RECV_POS = (RECV_POS + 2) % RX_BUFFER_SIZE; } + break; + } + } + return PacketData::UNDEF(EmptyLayer::new()); +} + +// TODO: Split the driver into separate bits so we can lock individual resources? +// ! Otherwise we will have a bottleneck? +pub struct RTL8139 { + pub config: Device, + recv_buffer: Option, // 12KB + send_buffer: Option, // 12KB + physical_mem_offset: Option, + pub my_ip_address: Option, + pub dhcp_server_ip: Option, + pub mac_address: Option, + pub open_ports: HashSet, + pub ports: HashMap>, + arp_table: Vec, +} + +impl RTL8139 { + // Initialize the card + pub fn init(&mut self, frame_allocator: &mut BootInfoFrameAllocator, physical_mem_offset: u64) -> bool { + let mut frames: Vec = vec![]; + // create recv and send buffer space + for _ in 0..6 { + let f = frame_allocator.allocate_frame(); + if f.is_none() { + return false; + } + frames.push(f.unwrap()); + } + // Ensure the sections are continous + let mut next_start = frames[0].start_address() + frames[0].size(); + for i in 1..6 { + // if frame isn't continous (can be acceptable on boundary between send and recv buffer) + if (frames[i].start_address() != next_start) && i != 3 { + println!("[ERR] Frames aren't continous {}", i); + return false; + } + next_start = frames[i].start_address() + frames[i].size(); + } + // Initialize data members + println!( + "[INFO] Setting up receive buffer at {}", + frames[0].start_address().as_u64() + ); + println!( + "[INFO] Setting up send buffer at {}", + frames[3].start_address().as_u64() + ); + self.send_buffer = Some(frames[0].start_address()); + self.recv_buffer = Some(frames[3].start_address()); + self.physical_mem_offset = Some(physical_mem_offset); + // call setup to init the card + let setup_status = self.setup(); + match setup_status { + true => println!("[INFO] Set up RTL8139 successful!!"), + false => println!("[ERR] Set up RTL8139 failed!!") + }; + return setup_status; + } + + fn new(configs: Vec) -> Option { + let mut use_dev = None; + for dev in configs.iter() { + if dev.vendor_id == RTL_VEND as u16 + && dev.device_id == RTL_DEV as u16 + && dev.class_code == PCIClassCodes::NetworkController + { + use_dev = Some(dev.clone()); + } + } + if use_dev.is_some() { + // set the io base for enabling and disabling interrupts + unsafe { IO_BASE = use_dev.as_ref().unwrap().io_base.unwrap() as usize }; + // Return the device and a 12KB physical region + return Some(RTL8139 { + config: use_dev.unwrap(), + recv_buffer: None, + send_buffer: None, + my_ip_address: None, + dhcp_server_ip: None, + physical_mem_offset: None, + mac_address: None, + open_ports: HashSet::with_capacity(10), + ports: HashMap::with_capacity(10), + arp_table: Vec::new(), + }); + } + None + } + + // True on success + // Interrupt handler and frame allocation must be done as a pre-req + // Also will set the mac address of the struct... (before then, the mac address is 0) + fn setup(&mut self) -> bool { + // disable the interrupt line to prevent early triggering of an interrupt before we register the handler + InterruptHandler::block_irq(self.config.irq.unwrap()); + + // enable bus mastering + let mut cr = self.config.read_command_register(); + cr.set_bus_master_bit(true); + self.config.write_command_register(cr); + + // Check bus mastering + let cr = self.config.read_command_register(); + match cr.get_bus_master_bit() { + true => println!("[INFO] Bus mastering enabled!"), + false => { + println!("[ERR] Enabling bus mastering failed"); + return false; + } + } + match cr.get_interrupt_disable_bit() { + true => println!("[ERR] Interrupts disabled"), + false => println!("[INFO] Interrupts enabled"), + } + // Turning on the RTL8139 + if self.config.io_base.is_none() { + println!("[ERR] Cannot find IO-Address"); + return false; + } + if self.config.irq.is_none() { + println!("[ERR] Cannot find IRQ"); + return false; + } + + // Register the interrupt handler for the card + IDT.lock() + .register_irq(self.config.irq.unwrap() as usize, network_handle); + + // Get MAC address + let mac_addr = self.config.io_base.unwrap() + 0x00; + let mut mac1 = Port::::new(mac_addr as u16); + let mut mac2 = Port::::new((mac_addr + 0x04) as u16); + let mut mac_addr = 0x0_u64; + unsafe { + mac_addr = mac_addr | mac1.read() as u64; + mac_addr = mac_addr | ((mac2.read() as u64) << 32); + }; + // save mac address + self.mac_address = Some(mac_addr); + println!("[INFO] MAC address is {:#10x}", mac_addr); + + // turn on the card + let addr = self.config.io_base.unwrap() + 0x52; + let mut port_config_1 = Port::::new(addr as u16); + unsafe { port_config_1.write(0x00) }; + + // Performing software reset + let cmd_reg = self.config.io_base.unwrap() + CR; + let mut port_rst = Port::::new(cmd_reg as u16); + unsafe { + port_rst.write(0x10); + while port_rst.read() & 0x10 != 0 {} // spin until we observe reset is over + }; + println!("[INFO] RTL8139 has been reset!"); + + // Enable receiver and transmitter + let mut port_recv_transmit = Port::::new(cmd_reg as u16); + unsafe { port_recv_transmit.write(CR_TE | CR_RE) }; + + // Configuring receive buffer + let rcr_reg = self.config.io_base.unwrap() + 0x44; + let mut rcr = Port::::new(rcr_reg as u16); + unsafe { + let broadcast = 0x08; // Accept broadcast packets sent to mac ff:ff:ff:ff:ff:ff + let multicast = 0x04; // Accept multicast packets + let physical_match = 0x02; // Accept physical matches + let promiscous = 0x01; // Accept all packets + // (1 << 7) is the WRAP bit, 0xf is broadcast, multicast, physical match, accept all packets + rcr.write(physical_match | multicast | broadcast | promiscous); + }; + + // Init receive buffer + let rcv_buf_reg = self.config.io_base.unwrap() + 0x30; + let mut rcv_buffer = Port::::new(rcv_buf_reg as u16); + unsafe { rcv_buffer.write(self.recv_buffer.unwrap().as_u64() as u32) }; + + // Set IMR + ISR + let imr_reg = self.config.io_base.unwrap() + 0x3C; + let mut imr = Port::::new(imr_reg as u16); + unsafe { imr.write(INTERRUPT_MASK) }; + + // Enable interrupts + // How I understood why: https://forum.osdev.org/viewtopic.php?f=1&t=27901 + InterruptHandler::unblock_irq(self.config.irq.unwrap()); + return true; + } + + // todo: remove this from driver code (add to socket.rs) + pub fn get_mac_from_ip(&self, ip: u32) -> u64 { + for entry in self.arp_table.iter() { + // todo: check for expired arps + if entry.ip == ip { + return entry.mac; + } + } + // send arp packet + + // wait for response + // recursively try again + return 0; + } + + pub fn send_packet(&self, packet_data: &Vec) -> () { + if self.send_buffer.is_none() || self.physical_mem_offset.is_none() { + panic!("RTL8139 is not initialized properly"); + } + // If we don't have a mac address, stop + let io_base = self.config.io_base; + if self.mac_address.is_none() || io_base.is_none() { + return; + } + + let virtual_buffer: VirtAddr = + VirtAddr::new(self.send_buffer.unwrap().as_u64() + self.physical_mem_offset.unwrap()); + let virtual_buffer_ptr: *mut u8 = virtual_buffer.as_mut_ptr(); + // Copy data into buffer + for i in 0..packet_data.len() { + unsafe { *(virtual_buffer_ptr.wrapping_add(i)) = packet_data[i] }; + } + + // TODO: Make this part of self... + let reg = TRANSMIT_REG[unsafe { TRANSMIT_IDX as usize }]; + let cmd = TRANSMIT_CMD[unsafe { TRANSMIT_IDX as usize }]; + + let mut reg_port = Port::::new((io_base.unwrap() + reg) as u16); + let mut cmd_port = Port::::new((io_base.unwrap() + cmd) as u16); + unsafe { + reg_port.write(virtual_buffer.as_u64() as u32); + cmd_port.write(packet_data.len() as u32); + } + // Send the packet from the buffer + unsafe { + TRANSMIT_IDX += 1; + TRANSMIT_IDX = TRANSMIT_IDX % 4; + }; + } +} + +// sudo qemu-system-x86_64 -M q35 -serial mon:stdio -nographic -netdev vmnet-bridged,id=net0,ifname=en0 -device rtl8139,netdev=net0,mac=00:11:22:33:44:55 +// \ No newline at end of file diff --git a/kernel/src/network/socket.rs b/kernel/src/network/socket.rs new file mode 100644 index 0000000..326c63b --- /dev/null +++ b/kernel/src/network/socket.rs @@ -0,0 +1,94 @@ +use crate::println; + +use super::{rtl8139::{NET_INFO, disable_network_interrupts, enable_network_interrupts}, layer::PacketData}; + +#[derive(Debug)] +pub enum NetworkErrors { + PortInUse, +} + +// todo: Implement a socket for user-space... + +pub struct RawSocket { + port: u16 +} + +impl RawSocket { + pub fn new(port: u16) -> Result { + disable_network_interrupts(); + let mut rtl_dev_info_guard = NET_INFO.lock(); + let rtl_dev_info = rtl_dev_info_guard.get_mut().unwrap(); + // Check if the port is in use + if rtl_dev_info.open_ports.contains(&port) { + enable_network_interrupts(); + return Err(NetworkErrors::PortInUse); + } + // If not then bind to it + rtl_dev_info.open_ports.insert(port); + enable_network_interrupts(); + Ok(RawSocket { port }) + } + + fn try_get_packet_inner(&self) -> Option { + let mut rtl_dev_info_guard = NET_INFO.lock(); + let rtl_dev_info = rtl_dev_info_guard.get_mut().unwrap(); + match rtl_dev_info.ports.get_mut(&self.port) { + Some(vec) => { + return vec.pop_front(); + }, + None => return None, + } + } + + // try to get a packet. Won't block, but will return Option + pub fn try_get_packet(&self) -> Option { + disable_network_interrupts(); + let pkt = self.try_get_packet_inner(); + enable_network_interrupts(); + return pkt; + } + + // Query a port for a packet. Will block until a packet arrives + pub fn get_packet(&self) -> PacketData { + let pkt; + loop { + x86_64::instructions::hlt(); + match self.try_get_packet() { + Some(next_pkt) => { + pkt = next_pkt; + break; + }, + None => {}, + } + } + return pkt; + } + + // Query a port for a packet. Will block until a packet arrives + pub fn get_packet_with_timeout(&self, timeout_s: u32) -> Option { + let mut pkt = None; + for _ in 0..(18 * timeout_s) { + println!("One"); + pkt = self.try_get_packet(); + if pkt.is_some() { break; } + println!("Here?"); + x86_64::instructions::hlt(); + println!("Two?"); + } + return pkt; + } + + pub fn close(&mut self) -> () { + disable_network_interrupts(); + let mut rtl_dev_info_guard = NET_INFO.lock(); + let rtl_dev_info = rtl_dev_info_guard.get_mut().unwrap(); + // close the port so that we don't receive anymore packets + rtl_dev_info.open_ports.remove(&self.port); + // Try to clear all the pending packets from the port + if rtl_dev_info.ports.contains_key(&self.port) { + let vec = rtl_dev_info.ports.get_mut(&self.port); + vec.unwrap().clear(); + } + enable_network_interrupts(); + } +} \ No newline at end of file diff --git a/kernel/src/network/tcp.rs b/kernel/src/network/tcp.rs new file mode 100644 index 0000000..e69de29 diff --git a/kernel/src/network/udp.rs b/kernel/src/network/udp.rs new file mode 100644 index 0000000..a1db6bd --- /dev/null +++ b/kernel/src/network/udp.rs @@ -0,0 +1,141 @@ +use super::{ + bytefield::Bytefield16, + ip::IPPacket, + layer::{HasChecksum, Layer, LayerType}, +}; + +use alloc::vec; +use alloc::vec::Vec; + +#[derive(Debug)] +pub struct UDPPacket { + pub ip_packet: IPPacket, // public for checksumming + pub src_port: Bytefield16, // 2 bytes + pub dest_port: Bytefield16, // 2 bytes + pub length: Bytefield16, // 2 bytes + checksum: Bytefield16, // 2 bytes + pub data: Vec, // a vector for data bytes if needed + // 10 bytes total +} + +impl UDPPacket { + pub fn new() -> Self { + UDPPacket { + ip_packet: IPPacket::new(), + src_port: Bytefield16::new(0), + dest_port: Bytefield16::new(0), + length: Bytefield16::new(0), + checksum: Bytefield16::new(0), + data: Vec::new(), + } + } + + pub fn gen(ip_packet: IPPacket, src_port: u16, dest_port: u16, length: u16) -> Self { + UDPPacket { + ip_packet, + src_port: Bytefield16::new(src_port), + dest_port: Bytefield16::new(dest_port), + length: Bytefield16::new(length + 8), // size of body + 8 bytes for UDP + checksum: Bytefield16::new(0), + data: Vec::new(), + } + } +} + +impl Layer for UDPPacket { + type Input = IPPacket; + fn parse(ip_layer: IPPacket, bytevec: &[u8]) -> (Self, usize, LayerType) + where + Self: Sized, + { + let mut packet = UDPPacket::new(); // create an empty packet + // Read 14 bytes + let mut i = 0; + packet.ip_packet = ip_layer; + packet.src_port = Bytefield16::read_inc(&bytevec[i..], &mut i); + packet.dest_port = Bytefield16::read_inc(&bytevec[i..], &mut i); + packet.length = Bytefield16::read_inc(&bytevec[i..], &mut i); + packet.checksum = Bytefield16::read_inc(&bytevec[i..], &mut i); + assert!(i == 8); // 8 bytes + let layer_type = match packet.dest_port.val() { + 68 => LayerType::DHCP, + _ => { + // read remaining bytes and place them into the data buffer + for _ in 0..(packet.length.val() - 8) { + packet.data.push(bytevec[i]); + i += 1; + } + assert!(i == packet.length.val() as usize); + LayerType::UNDEF + }, + }; + return (packet, i, layer_type); + } + + fn serialize(&self) -> alloc::vec::Vec { + let mut res = vec![]; + res.extend(self.ip_packet.serialize()); + res.extend(self.src_port.data); + res.extend(self.dest_port.data); + res.extend(self.length.data); + res.extend(self.checksum.data); + res.extend(&self.data); // most of the time should be empty + assert!(res.len() == (8 + self.ip_packet.serialize().len() + self.data.len())); + res + } + + fn packet_size() -> u16 { + 8 + } +} + +impl HasChecksum for UDPPacket { + fn calculate_checksum(&mut self, data: &[u8]) -> () { + // Starting vars + let mut sum: u32 = 0; + self.length.swap_endianness(); + let mut udp_len = self.length.val() as usize; + self.length.swap_endianness(); // swap back because we don't want to permanently mutate the length when calculating... + + // First we do the IP as a pseduo header + let ip = &self.ip_packet; + sum += (ip.source_ip.data[0] as u32) | (ip.source_ip.data[1] as u32) << 8; + sum += (ip.source_ip.data[2] as u32) | (ip.source_ip.data[3] as u32) << 8; + sum += (ip.destination_ip.data[0] as u32) | (ip.destination_ip.data[1] as u32) << 8; + sum += (ip.destination_ip.data[2] as u32) | (ip.destination_ip.data[3] as u32) << 8; + + // Sum protocol and length + let protocol = Bytefield16::new(ip.protocol.as_byte() as u16); + sum += (protocol.data[0] as u32) | (protocol.data[1] as u32) << 8; + sum += (self.length.data[0] as u32) | (self.length.data[1] as u32) << 8; + + // Sum the body + self.checksum = Bytefield16::new(0); + let mut ptr = 0; + while udp_len > 1 { + sum += (data[ptr] as u32) | ((data[ptr + 1] as u32) << 8); + udp_len -= 2; + ptr += 2; + } + + if data.len() % 2 == 1 { + // Add the padding if the packet length is odd + sum += data[ptr] as u32; + } + + // Add the carries + while sum > 0xFFFF { + sum = (sum & 0xFFFF) + (sum >> 16); + } + + // One's complement + let mut res = !sum as u16; + // Swap the bytes because we did our sum in big endian + // (and the bytefield will try to convert to big endian) + res = ((res >> 8) & 0xFF) | ((res & 0xFF) << 8); + + // Return the one's complement of sum + self.checksum = Bytefield16::new(res); + } +} + diff --git a/src/main.rs b/src/main.rs index 7572840..fdc19f5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,6 +9,17 @@ fn main() { let uefi = true; let mut cmd = std::process::Command::new("qemu-system-x86_64"); + // ---- networking related arguments ---- // + // related to issue with kernel driver freezing. see https://wiki.osdev.org/RTL8139 + cmd.arg("-machine").arg("kernel_irqchip=off"); + // Virtualizing a network, has an entry point to inject packets into the isolated network + // Can send UDP from localhost:5555 to SOUP:5554 + cmd.arg("-netdev").arg("user,id=net0,hostfwd=udp::5555-:5554"); + // Making sure we have the rtl8139 as a hardware resource + cmd.arg("-device").arg("rtl8139,netdev=net0,mac=00:11:22:33:44:55"); + // Recording the network traffic in order to conduct debugging and analysis + cmd.arg("-object").arg("filter-dump,id=f1,netdev=net0,file=/tmp/dump.pcap"); + if uefi { cmd.arg("-bios").arg(ovmf_prebuilt::ovmf_pure_efi()); cmd.arg("-drive") From 3249ca61d84c6a1a23d0352eac4e79ba4563c6dc Mon Sep 17 00:00:00 2001 From: EthanLavi <89608076+EthanLavi@users.noreply.github.com> Date: Tue, 31 Oct 2023 15:14:29 -0400 Subject: [PATCH 02/36] Fixing clippy and addressing comments --- kernel/src/interrupts.rs | 26 +-- kernel/src/main.rs | 8 +- kernel/src/network/arp.rs | 2 +- kernel/src/network/bytefield.rs | 52 +++--- kernel/src/network/command_register.rs | 58 +++---- kernel/src/network/devices.rs | 24 +-- kernel/src/network/dhcp.rs | 4 +- kernel/src/network/e1000.rs | 211 ------------------------- kernel/src/network/ethernet.rs | 2 +- kernel/src/network/init.rs | 4 +- kernel/src/network/ip.rs | 16 +- kernel/src/network/layer.rs | 4 +- kernel/src/network/netsync.rs | 10 +- kernel/src/network/raw_array.rs | 4 +- kernel/src/network/rtl8139.rs | 44 +++--- kernel/src/network/socket.rs | 26 ++- kernel/src/network/udp.rs | 12 +- src/main.rs | 2 +- 18 files changed, 140 insertions(+), 369 deletions(-) delete mode 100644 kernel/src/network/e1000.rs diff --git a/kernel/src/interrupts.rs b/kernel/src/interrupts.rs index 18f7334..c65d9f1 100644 --- a/kernel/src/interrupts.rs +++ b/kernel/src/interrupts.rs @@ -1,6 +1,6 @@ use lazy_static::lazy_static; use x86_64::structures::idt::{InterruptDescriptorTable, InterruptStackFrame}; -use crate::{gdt, println, print}; +use crate::{gdt, println}; use pic8259::ChainedPics; use spin; use x86_64::structures::idt::PageFaultErrorCode; @@ -49,33 +49,35 @@ impl InterruptHandler { InterruptHandler { idt } } - pub fn init(&self) -> (){ + pub fn init(&self){ unsafe { self.idt.load_unsafe() }; } // Static function for disabling an irq - pub fn unblock_irq(irq_num: u8) -> () { - let data = unsafe { PICS.lock().read_masks() }; + pub fn unblock_irq(irq_num: u8) { + let mut locked_pics = PICS.lock(); + let data = unsafe { locked_pics.read_masks() }; // set the irq bit to 0 if irq_num < 8 { - unsafe { PICS.lock().write_masks(data[0] & !(1 << irq_num), data[1]) }; + unsafe { locked_pics.write_masks(data[0] & !(1 << irq_num), data[1]) }; } else { - unsafe { PICS.lock().write_masks(data[0], data[1] & !(1 << irq_num - 8)) }; + unsafe { locked_pics.write_masks(data[0], data[1] & !(1 << (irq_num - 8))) }; } } // Static function for re-enabling an IRQ - pub fn block_irq(irq_num: u8) -> () { - let data = unsafe { PICS.lock().read_masks() }; + pub fn block_irq(irq_num: u8) { + let mut locked_pics = PICS.lock(); + let data = unsafe { locked_pics.read_masks() }; // set the irq bit to 1 if irq_num < 8 { - unsafe { PICS.lock().write_masks(data[0] | 1 << irq_num, data[1]) }; + unsafe { locked_pics.write_masks(data[0] | 1 << irq_num, data[1]) }; } else { - unsafe { PICS.lock().write_masks(data[0], data[1] | 1 << irq_num - 8) }; + unsafe { locked_pics.write_masks(data[0], data[1] | 1 << (irq_num - 8)) }; } } - pub fn register_irq(&mut self, irq_num: usize, handler: InterruptHandlerFunc) -> (){ + pub fn register_irq(&mut self, irq_num: usize, handler: InterruptHandlerFunc){ println!("Registered Handler @ {}", irq_num + 32); self.idt[irq_num + 32].set_handler_fn(handler); unsafe { self.idt.load_unsafe() }; @@ -118,8 +120,6 @@ extern "x86-interrupt" fn double_fault_handler( } extern "x86-interrupt" fn timer_interrupt_handler(_stack_frame: InterruptStackFrame) { - print!("."); - unsafe { PICS.lock() .notify_end_of_interrupt(InterruptIndex::Timer.as_u8()); diff --git a/kernel/src/main.rs b/kernel/src/main.rs index 4a2a957..0de8cc6 100644 --- a/kernel/src/main.rs +++ b/kernel/src/main.rs @@ -13,7 +13,7 @@ use core::panic::PanicInfo; use kernel::{ framebuffer, hlt_loop, network::{ethernet::{EthernetPacket, self}, udp::UDPPacket, ip::{IPPacket, Protocol}, layer::{Layer, HasChecksum, LayerType}, rtl8139::{disable_network_interrupts, NET_INFO, enable_network_interrupts}, socket::RawSocket, init::init_dhcp}, - println, + println, print, task::keyboard, task::{executor::Executor, Task}, }; @@ -76,9 +76,9 @@ fn kernel_main(boot_info: &'static mut BootInfo) -> ! { rtl_driver.unwrap().init(&mut frame_allocator, phys_mem_offset) }; // so that the NET INFO gets released let status_init_dhcp = init_dhcp(2); - if !status_init && false { + if !status_init { println!("[ERR] Cannot init RTL8139"); - } else if !status_init_dhcp && false { + } else if !status_init_dhcp { println!("[ERR] DHCP error -- whats my ip?"); } else { let raw_socket = RawSocket::new(5554); @@ -92,7 +92,7 @@ fn kernel_main(boot_info: &'static mut BootInfo) -> ! { let data_cloned_len = data_cloned.len(); let user_message = String::from_utf8(udp_pkt.data); match user_message { - Ok(message) => println!("[USER] {}", message), + Ok(message) => print!("[USER] {}", message), Err(err) => println!("[USER-ERR] {:?}", err), } diff --git a/kernel/src/network/arp.rs b/kernel/src/network/arp.rs index c1cfab5..e287087 100644 --- a/kernel/src/network/arp.rs +++ b/kernel/src/network/arp.rs @@ -74,7 +74,7 @@ impl Layer for ArpPacket { packet.recp_ip = Bytefield32::read_inc(&bytevec[i..], &mut i); assert!(i == 28); // Arp packet should be 28 bytes - return (packet, i, LayerType::UNDEF); + (packet, i, LayerType::UNDEF) } fn serialize(&self) -> Vec { diff --git a/kernel/src/network/bytefield.rs b/kernel/src/network/bytefield.rs index f2581e7..624d9f4 100644 --- a/kernel/src/network/bytefield.rs +++ b/kernel/src/network/bytefield.rs @@ -29,13 +29,13 @@ impl Bytefield128 { pub fn val(&self) -> u128 { let mut res = 0_u128; for i in 0..Self::size() { - res = res | ((self.data[i] as u128) << (i * 8)); + res |= (self.data[i] as u128) << (i * 8); } - return res; + res } // Swap the endianness of the data - pub fn swap_endianness(&mut self) -> () { + pub fn swap_endianness(&mut self) { self.data.reverse(); } @@ -72,13 +72,13 @@ impl Bytefield64 { pub fn val(&self) -> u64 { let mut res = 0_u64; for i in 0..Self::size() { - res = res | ((self.data[i] as u64) << (i * 8)); + res |= (self.data[i] as u64) << (i * 8); } - return res; + res } // Swap the endianness of the data - pub fn swap_endianness(&mut self) -> () { + pub fn swap_endianness(&mut self) { self.data.reverse(); } @@ -115,13 +115,13 @@ impl Bytefield48 { pub fn val(&self) -> u64 { let mut res = 0_u64; for i in 0..Self::size() { - res = res | ((self.data[i] as u64) << (i * 8)); + res |= (self.data[i] as u64) << (i * 8); } - return res; + res } // Swap the endianness of the data - pub fn swap_endianness(&mut self) -> () { + pub fn swap_endianness(&mut self) { self.data.reverse(); } @@ -158,13 +158,13 @@ impl Bytefield32 { pub fn val(&self) -> u32 { let mut res = 0_u32; for i in 0..Self::size() { - res = res | ((self.data[i] as u32) << (i * 8)); + res |= (self.data[i] as u32) << (i * 8); } - return res; + res } // Swap the endianness of the data - pub fn swap_endianness(&mut self) -> () { + pub fn swap_endianness(&mut self) { self.data.reverse(); } @@ -182,7 +182,7 @@ pub struct Bytefield16 { impl Bytefield16 { // Create a bytefield and swap endian-ness pub fn new(val: u16) -> Self { - Self { data: [(val >> 1 * 8 & 0xFF) as u8, (val >> 0 * 8 & 0xFF) as u8] } + Self { data: [(val >> 8 & 0xFF) as u8, (val & 0xFF) as u8] } } pub fn read_inc(bytevec: &[u8], i: &mut usize) -> Self { @@ -194,13 +194,13 @@ impl Bytefield16 { pub fn val(&self) -> u16 { let mut res = 0_u16; for i in 0..Self::size() { - res = res | ((self.data[i] as u16) << (i * 8)); + res |= (self.data[i] as u16) << (i * 8); } - return res; + res } // Swap the endianness of the data - pub fn swap_endianness(&mut self) -> () { + pub fn swap_endianness(&mut self) { self.data.reverse(); } @@ -240,65 +240,65 @@ impl Bytefield8 { // ===== DEFINING INDEXING OPERATIONS ====== // impl Index for Bytefield128 { type Output = u8; - fn index<'a>(&'a self, i: usize) -> &'a u8 { + fn index(&self, i: usize) -> &u8 { &self.data[i] } } impl IndexMut for Bytefield128 { - fn index_mut<'a>(&'a mut self, i: usize) -> &'a mut u8 { + fn index_mut(&mut self, i: usize) -> &mut u8 { &mut self.data[i] } } impl Index for Bytefield64 { type Output = u8; - fn index<'a>(&'a self, i: usize) -> &'a u8 { + fn index(&self, i: usize) -> &u8 { &self.data[i] } } impl IndexMut for Bytefield64 { - fn index_mut<'a>(&'a mut self, i: usize) -> &'a mut u8 { + fn index_mut(&mut self, i: usize) -> &mut u8 { &mut self.data[i] } } impl Index for Bytefield48 { type Output = u8; - fn index<'a>(&'a self, i: usize) -> &'a u8 { + fn index(&self, i: usize) -> &u8 { &self.data[i] } } impl IndexMut for Bytefield48 { - fn index_mut<'a>(&'a mut self, i: usize) -> &'a mut u8 { + fn index_mut(&mut self, i: usize) -> &mut u8 { &mut self.data[i] } } impl Index for Bytefield32 { type Output = u8; - fn index<'a>(&'a self, i: usize) -> &'a u8 { + fn index(&self, i: usize) -> &u8 { &self.data[i] } } impl IndexMut for Bytefield32 { - fn index_mut<'a>(&'a mut self, i: usize) -> &'a mut u8 { + fn index_mut(&mut self, i: usize) -> &mut u8 { &mut self.data[i] } } impl Index for Bytefield16 { type Output = u8; - fn index<'a>(&'a self, i: usize) -> &'a u8 { + fn index(&self, i: usize) -> &u8 { &self.data[i] } } impl IndexMut for Bytefield16 { - fn index_mut<'a>(&'a mut self, i: usize) -> &'a mut u8 { + fn index_mut(&mut self, i: usize) -> &mut u8 { &mut self.data[i] } } \ No newline at end of file diff --git a/kernel/src/network/command_register.rs b/kernel/src/network/command_register.rs index acc1249..5e29067 100644 --- a/kernel/src/network/command_register.rs +++ b/kernel/src/network/command_register.rs @@ -10,92 +10,92 @@ impl CommandRegister { // basic getter for internal data pub fn data(&self) -> u16 { - return self.cr; + self.cr } // 0th bit pub fn get_io_space_bit(&self) -> bool { - return (self.cr & 0x1) != 0; + (self.cr & 0x1) != 0 } - pub fn set_io_space_bit(&mut self, is_on: bool) -> (){ + pub fn set_io_space_bit(&mut self, is_on: bool){ match is_on { - true => self.cr = self.cr | 0x1, - false => self.cr = self.cr & !0x1, + true => self.cr |= 0x1, + false => self.cr &= !0x1, } } // 1st bit pub fn get_memory_space_bit(&self) -> bool { - return (self.cr & 0x2) != 0; + (self.cr & 0x2) != 0 } - pub fn set_memory_space_bit(&mut self, is_on: bool) -> (){ + pub fn set_memory_space_bit(&mut self, is_on: bool){ match is_on { - true => self.cr = self.cr | 0x2, - false => self.cr = self.cr & !0x2, + true => self.cr |= 0x2, + false => self.cr &= !0x2, } } // 2nd bit pub fn get_bus_master_bit(&self) -> bool { - return (self.cr & 0x4) != 0; + (self.cr & 0x4) != 0 } - pub fn set_bus_master_bit(&mut self, is_on: bool) -> (){ + pub fn set_bus_master_bit(&mut self, is_on: bool){ match is_on { - true => self.cr = self.cr | 0x4, - false => self.cr = self.cr & !0x4, + true => self.cr |= 0x4, + false => self.cr &= !0x4, } } // 3rd bit pub fn get_special_cycles_bit(&self) -> bool { - return (self.cr & 0x8) != 0; + (self.cr & 0x8) != 0 } // 4th bit pub fn get_memory_write_invalidate_enable_bit(&self) -> bool { - return (self.cr & 0x10) != 0; + (self.cr & 0x10) != 0 } // 5th bit pub fn get_vga_palette_snoop_bit(&self) -> bool { - return (self.cr & 0x20) != 0; + (self.cr & 0x20) != 0 } // 6th bit pub fn get_parity_err_res_bit(&self) -> bool { - return (self.cr & 0x40) != 0; + (self.cr & 0x40) != 0 } - pub fn set_parity_err_res_bit(&mut self, is_on: bool) -> (){ + pub fn set_parity_err_res_bit(&mut self, is_on: bool){ match is_on { - true => self.cr = self.cr | 0x40, - false => self.cr = self.cr & !0x40, + true => self.cr |= 0x40, + false => self.cr &= !0x40, } } // 8th bit pub fn get_serr_enable_bit(&self) -> bool { - return (self.cr & 0x100) != 0; + (self.cr & 0x100) != 0 } - pub fn set_serr_enable_bit(&mut self, is_on: bool) -> (){ + pub fn set_serr_enable_bit(&mut self, is_on: bool){ match is_on { - true => self.cr = self.cr | 0x100, - false => self.cr = self.cr & !0x100, + true => self.cr |= 0x100, + false => self.cr &= !0x100, } } // 9th bit pub fn get_fast_back_to_back_enable_bit(&self) -> bool { - return (self.cr & 0x200) != 0; + (self.cr & 0x200) != 0 } // 10th bit pub fn get_interrupt_disable_bit(&self) -> bool { - return (self.cr & 0x400) != 0; + (self.cr & 0x400) != 0 } - pub fn set_interrupt_disable_bit(&mut self, is_on: bool) -> (){ + pub fn set_interrupt_disable_bit(&mut self, is_on: bool){ match is_on { - true => self.cr = self.cr | 0x400, - false => self.cr = self.cr & !0x400, + true => self.cr |= 0x400, + false => self.cr &= !0x400, } } } \ No newline at end of file diff --git a/kernel/src/network/devices.rs b/kernel/src/network/devices.rs index d5a61b6..9dac963 100644 --- a/kernel/src/network/devices.rs +++ b/kernel/src/network/devices.rs @@ -81,7 +81,7 @@ fn create_confg_address(bus: u8, slot: u8, func: u8, offset: u8){ let lslot = slot as u32; let lfunc = func as u32; - let address = ((lbus << 16) | (lslot << 11) | (lfunc << 8) | ((offset as u32) & 0xFC) | (0x80000000 as u32)) as u32; + let address = (lbus << 16) | (lslot << 11) | (lfunc << 8) | ((offset as u32) & 0xFC) | 0x80000000_u32; let mut port = Port::::new(CONFIG_ADDRESS); // Write the address unsafe { port.write(address) }; @@ -95,13 +95,13 @@ fn pci_config_read_dword(bus: u8, slot: u8, func: u8, offset: u8) -> u32 { let mut port = Port::::new(CONFIG_DATA); // Read the data let data: u32 = unsafe { port.read() }; - return data; + data } /// Query PCI to write 16 bits of data of a device defined by (bus, slot) /// Will offset into the PCI configuration header by offset // Func is used for multi-function devices... (don't think that will be used with network card) -fn pci_config_write_word(bus: u8, slot: u8, func: u8, offset: u8, word: u16) -> () { +fn pci_config_write_word(bus: u8, slot: u8, func: u8, offset: u8, word: u16) { create_confg_address(bus, slot, func, offset); let mut port = Port::::new(CONFIG_DATA); // Read the data @@ -120,14 +120,14 @@ fn pci_check_vendor(bus: u8, slot: u8) -> Option { if vendor != 0xFFFF { return Some(vendor); } - return None; + None } // Assumes a device at (bus, slot) // Will extract the class code from the configuration space fn pci_get_device_id(bus: u8, slot: u8) -> u16 { let device_id = (pci_config_read_dword(bus, slot, 0, 0) >> 16) & 0xFFFF; - return device_id as u16; + device_id as u16 } // Assumes a device at (bus, slot) @@ -138,7 +138,7 @@ fn pci_get_class_code(bus: u8, slot: u8) -> (PCIClassCodes, u8) { // let _progif = ((code >> 8) & 0xFF) as u8; let subclass = ((code >> 16) & 0xFF) as u8; let class = ((code >> 24) & 0xFF) as u8; - return (PCIClassCodes::from(class), subclass); + (PCIClassCodes::from(class), subclass) } // Read the interrupt line of the PCI Configuration address space @@ -154,11 +154,11 @@ fn pci_get_irq(bus: u8, slot: u8) -> Option { fn pci_get_cmd_reg(bus: u8, slot: u8) -> CommandRegister { let cr = pci_config_read_dword(bus, slot, 0, 0x4); - return CommandRegister::new((cr & 0xFFFF) as u16); + CommandRegister::new((cr & 0xFFFF) as u16) } // Set the command register -fn pci_set_cmd_reg(bus: u8, slot: u8, cr: CommandRegister) -> (){ +fn pci_set_cmd_reg(bus: u8, slot: u8, cr: CommandRegister) { // Write the register with cr pci_config_write_word(bus, slot, 0, 0x4, cr.data()); } @@ -180,17 +180,17 @@ fn pci_get_io_base(bus: u8, slot: u8) -> Option { } // Restore original configuration pci_set_cmd_reg(bus, slot, org); - return None + None } impl Device { // Read the command register of the device pub fn read_command_register(&self) -> CommandRegister { - return pci_get_cmd_reg(self.bus, self.slot); + pci_get_cmd_reg(self.bus, self.slot) } // Write to the command register of the device - pub fn write_command_register(&self, new_val: CommandRegister) -> () { + pub fn write_command_register(&self, new_val: CommandRegister) { pci_set_cmd_reg(self.bus, self.slot, new_val); } } @@ -229,5 +229,5 @@ pub fn scan_devices() -> Vec { }); } - return results; + results } \ No newline at end of file diff --git a/kernel/src/network/dhcp.rs b/kernel/src/network/dhcp.rs index 2e1e903..129f762 100644 --- a/kernel/src/network/dhcp.rs +++ b/kernel/src/network/dhcp.rs @@ -99,7 +99,7 @@ impl DHCPPacket { let start_ip = start_udp - (IPPacket::packet_size() as usize); dhcp.udp_packet.ip_packet.calculate_checksum(&data[start_ip..start_udp]); dhcp.udp_packet.calculate_checksum(&data[start_udp..]); - return dhcp; + dhcp } } @@ -131,7 +131,7 @@ impl Layer for DHCPPacket { let left_to_parse = packet.udp_packet.length.val() - 308; i += left_to_parse as usize; assert!(i >= 300); // 300 bytes - return (packet, i, LayerType::UNDEF); + (packet, i, LayerType::UNDEF) } fn serialize(&self) -> alloc::vec::Vec { diff --git a/kernel/src/network/e1000.rs b/kernel/src/network/e1000.rs deleted file mode 100644 index 08528d1..0000000 --- a/kernel/src/network/e1000.rs +++ /dev/null @@ -1,211 +0,0 @@ -use alloc::boxed::Box; - -use super::devices::Device; - -const INTEL_VEND: u32 = 0x8086; // Vendor ID for Intel -const E1000_DEV: u32 = 0x100E; // Device ID for the e1000 Qemu, Bochs, and VirtualBox emmulated NICs -const E1000_I217: u32 = 0x153A; // Device ID for Intel I217 -const E1000_82577LM: u32 = 0x10EA; // Device ID for Intel 82577LM - -// Constants From https://wiki.osdev.org/Intel_Ethernet_i217 -const REG_CTRL: u32 = 0x0000; -const REG_STATUS: u32 = 0x0008; -const REG_EEPROM: u32 = 0x0014; -const REG_CTRL_EXT: u32 = 0x0018; -const REG_IMASK: u32 = 0x00D0; -const REG_RCTRL: u32 = 0x0100; -const REG_RXDESCLO: u32 = 0x2800; -const REG_RXDESCHI: u32 = 0x2804; -const REG_RXDESCLEN: u32 = 0x2808; -const REG_RXDESCHEAD: u32 = 0x2810; -const REG_RXDESCTAIL: u32 = 0x2818; - -const REG_TCTRL: u32 = 0x0400; -const REG_TXDESCLO: u32 = 0x3800; -const REG_TXDESCHI: u32 = 0x3804; -const REG_TXDESCLEN: u32 = 0x3808; -const REG_TXDESCHEAD: u32 = 0x3810; -const REG_TXDESCTAIL: u32 = 0x3818; - -const REG_RDTR: u32 = 0x2820; // RX Delay Timer Register -const REG_RXDCTL: u32 = 0x2828; // RX Descriptor Control -const REG_RADV: u32 = 0x282C; // RX Int. Absolute Delay Timer -const REG_RSRPD: u32 = 0x2C00; // RX Small Packet Detect Interrupt - -const REG_TIPG: u32 = 0x0410; // Transmit Inter Packet Gap -const ECTRL_SLU: u32 = 0x40; //set link up - -const RCTL_EN: u32 = 1 << 1; // Receiver Enable -const RCTL_SBP: u32 = 1 << 2; // Store Bad Packets -const RCTL_UPE: u32 = 1 << 3; // Unicast Promiscuous Enabled -const RCTL_MPE: u32 = 1 << 4; // Multicast Promiscuous Enabled -const RCTL_LPE: u32 = 1 << 5; // Long Packet Reception Enable -const RCTL_LBM_NONE: u32 = 0 << 6; // No Loopback -const RCTL_LBM_PHY: u32 = 3 << 6; // PHY or external SerDesc loopback -const RTCL_RDMTS_HALF: u32 = 0 << 8; // Free Buffer Threshold is 1/2 of RDLEN -const RTCL_RDMTS_QUARTER: u32 = 1 << 8; // Free Buffer Threshold is 1/4 of RDLEN -const RTCL_RDMTS_EIGHTH: u32 = 2 << 8; // Free Buffer Threshold is 1/8 of RDLEN -const RCTL_MO_36: u32 = 0 << 12; // Multicast Offset - bits 47:36 -const RCTL_MO_35: u32 = 1 << 12; // Multicast Offset - bits 46:35 -const RCTL_MO_34: u32 = 2 << 12; // Multicast Offset - bits 45:34 -const RCTL_MO_32: u32 = 3 << 12; // Multicast Offset - bits 43:32 -const RCTL_BAM: u32 = 1 << 15; // Broadcast Accept Mode -const RCTL_VFE: u32 = 1 << 18; // VLAN Filter Enable -const RCTL_CFIEN: u32 = 1 << 19; // Canonical Form Indicator Enable -const RCTL_CFI: u32 = 1 << 20; // Canonical Form Indicator Bit Value -const RCTL_DPF: u32 = 1 << 22; // Discard Pause Frames -const RCTL_PMCF: u32 = 1 << 23; // Pass MAC Control Frames -const RCTL_SECRC: u32 = 1 << 26; // Strip Ethernet CRC - -// Buffer Sizes -const RCTL_BSIZE_256: u32 = 3 << 16; -const RCTL_BSIZE_512: u32 = 2 << 16; -const RCTL_BSIZE_1024: u32 = 1 << 16; -const RCTL_BSIZE_2048: u32 = 0 << 16; -const RCTL_BSIZE_4096: u32 = (3 << 16) | (1 << 25); -const RCTL_BSIZE_8192: u32 = (2 << 16) | (1 << 25); -const RCTL_BSIZE_16384: u32 = (1 << 16) | (1 << 25); - -// Transmit Command -const CMD_EOP: u32 = 1 << 0; // End of Packet -const CMD_IFCS: u32 = 1 << 1; // Insert FCS -const CMD_IC: u32 = 1 << 2; // Insert Checksum -const CMD_RS: u32 = 1 << 3; // Report Status -const CMD_RPS: u32 = 1 << 4; // Report Packet Sent -const CMD_VLE: u32 = 1 << 6; // VLAN Packet Enable -const CMD_IDE: u32 = 1 << 7; // Interrupt Delay Enable - -// TCTL Register -const TCTL_EN: u32 = 1 << 1; // Transmit Enable -const TCTL_PSP: u32 = 1 << 3; // Pad Short Packets -const TCTL_CT_SHIFT: u32 = 4; // Collision Threshold -const TCTL_COLD_SHIFT: u32 = 12; // Collision Distance -const TCTL_SWXOFF: u32 = 1 << 22; // Software XOFF Transmission -const TCTL_RTLC: u32 = 1 << 24; // Re-transmit on Late Collision - -const TSTA_DD: u32 = 1 << 0; // Descriptor Done -const TSTA_EC: u32 = 1 << 1; // Excess Collisions -const TSTA_LC: u32 = 1 << 2; // Late Collision -const LSTA_TU: u32 = 1 << 3; // Transmit Underrun - -#[repr(C)] -struct e1000_rx_desc { - volatile uint64_t addr; - volatile uint16_t length; - volatile uint16_t checksum; - volatile uint8_t status; - volatile uint8_t errors; - volatile uint16_t special; -} __attribute__((packed)); - -#[repr(C)] -struct e1000_tx_desc { - volatile uint64_t addr; - volatile uint16_t length; - volatile uint8_t cso; - volatile uint8_t cmd; - volatile uint8_t status; - volatile uint8_t css; - volatile uint16_t special; -} __attribute__((packed)); - - -struct E1000 { - // Type of BAR0 - bar_type: u8, - - // IO Base Address - io_base: u16, - - // MMIO Base Address - mem_base: u64, - - // A flag indicating if eeprom exists - eerprom_exists: bool, - - // A buffer for storing the MAC address - mac: [u8; 6], - - struct e1000_rx_desc *rx_descs[E1000_NUM_RX_DESC]; // Receive Descriptor Buffers - struct e1000_tx_desc *tx_descs[E1000_NUM_TX_DESC]; // Transmit Descriptor Buffers - - // Current Receive Descriptor Buffer - rx_cur: u16, - - // Current Transmit Descriptor Buffer - tx_cur: u16, -} - -impl E1000 { - // Constructor. takes as a parameter a pointer to an object that encapsulate all he PCI configuration data of the device - pub fn new(pci_config: Device) -> E1000 { - unimplemented!("Unimplemented"); - } - - // Perform initialization tasks and starts the driver - pub fn init() -> () { - unimplemented!("Unimplemented"); - } - - // This method should be called by the interrupt handler - pub fn fire(p_interruptContext: InterruptContext) -> () { - unimplemented!("Unimplemented"); - } - - // Returns the MAC address - pub fn get_mac_address() -> [u8; 6]{ - unimplemented!("Unimplemented"); - } - - pub fn send_packet(p_data: Box<()>, p_len: u16) -> u32 { - unimplemented!("Unimplemented"); - } - - // Send Commands and read results From NICs either using MMIO or IO Ports - fn write_command(p_address: u16, p_value: u32) -> () { - unimplemented!("Unimplemented"); - } - fn send_command(p_address: u16) -> () { - unimplemented!("Unimplemented"); - } - - // Detect if EE Prom exists - fn detect_ee_prom() -> bool { - unimplemented!("Unimplemented"); - } - - // Read 4 bytes from a specific EEProm Address - fn eeprom_read(addr: u8) -> u32 { - unimplemented!("Unimplemented"); - } - - // Read MAC Address - fn read_mac_address() -> () { - unimplemented!("Unimplemented"); - } - - // Start up the network - fn start_link() -> () { - unimplemented!("Unimplemented"); - } - - // Initialize receive descriptors an buffers - fn rxinit() -> () { - unimplemented!("Unimplemented"); - } - - // Initialize transmit descriptors an buffers - fn txinit() -> () { - unimplemented!("Unimplemented"); - } - - // Enable interrupts - fn enable_interrupts() -> () { - unimplemented!("Unimplemented"); - } - - // Handle a packet reception - fn handle_receive() -> () { - unimplemented!("Unimplemented"); - } -} \ No newline at end of file diff --git a/kernel/src/network/ethernet.rs b/kernel/src/network/ethernet.rs index f52f9ff..9e30442 100644 --- a/kernel/src/network/ethernet.rs +++ b/kernel/src/network/ethernet.rs @@ -73,7 +73,7 @@ impl Layer for EthernetPacket { EthType::RoCE => LayerType::UNDEF, EthType::Unknown => LayerType::UNDEF, }; - return (packet, i, layer_type); + (packet, i, layer_type) } fn serialize(&self) -> Vec { diff --git a/kernel/src/network/init.rs b/kernel/src/network/init.rs index eb022e3..a4d7a6b 100644 --- a/kernel/src/network/init.rs +++ b/kernel/src/network/init.rs @@ -8,7 +8,7 @@ use crate::network::layer::{LayerType, Layer}; use crate::network::socket::RawSocket; use super::constants::{BROADCAST_ADDR, DHCP_CLIENT_PORT, BROADCAST_MAC}; -pub fn init() -> (){ +pub fn init(){ // todo bundle the init phases } @@ -59,5 +59,5 @@ pub fn init_dhcp(wait_timeout: u8) -> bool { rtl_dev_info.dhcp_server_ip = Some(dhcp_res.server_ip.val()); } enable_network_interrupts(); - return true; + true } diff --git a/kernel/src/network/ip.rs b/kernel/src/network/ip.rs index 25161d2..bbd3a24 100644 --- a/kernel/src/network/ip.rs +++ b/kernel/src/network/ip.rs @@ -25,10 +25,6 @@ impl Protocol { _ => Self::Unsupported, } } - - pub fn as_byte(&self) -> u8 { - *self as u8 - } } struct WrappedU16 { @@ -130,7 +126,7 @@ impl Layer for IPPacket { Protocol::RDP => LayerType::UNDEF, Protocol::Unsupported => LayerType::UNDEF, }; - return (packet, i, layer_type); + (packet, i, layer_type) } fn serialize(&self) -> Vec { @@ -142,7 +138,7 @@ impl Layer for IPPacket { res.extend(self.identification.data); res.extend(self.flags_fragment_offset.data); res.push(self.ttl); - res.push(self.protocol.as_byte()); + res.push(self.protocol as u8); res.extend(self.checksum.data); res.extend(self.source_ip.data); res.extend(self.destination_ip.data); @@ -156,7 +152,7 @@ impl Layer for IPPacket { } impl HasChecksum for IPPacket { - fn calculate_checksum(&mut self, data: &[u8]) -> () { + fn calculate_checksum(&mut self, data: &[u8]) { // Starting vars let mut sum: u32 = 0; @@ -180,11 +176,9 @@ impl HasChecksum for IPPacket { sum = (sum & 0xFFFF) + (sum >> 16); } - // One's complement - let mut res = !sum as u16; - // Swap the bytes because we did our sum in big endian + // One's complement and swap the bytes because we did our sum in big endian // (and the bytefield will try to convert to big endian) - res = ((res >> 8) & 0xFF) | ((res & 0xFF) << 8); + let res = u16::swap_bytes(!sum as u16); // Return the one's complement of sum self.checksum = Bytefield16::new(res); diff --git a/kernel/src/network/layer.rs b/kernel/src/network/layer.rs index ce42f4f..5ac71a8 100644 --- a/kernel/src/network/layer.rs +++ b/kernel/src/network/layer.rs @@ -44,7 +44,7 @@ impl Layer for EmptyLayer { pub trait HasChecksum { /// Calculate the checksum and self mutate - fn calculate_checksum(&mut self, data: &[u8]) -> (); + fn calculate_checksum(&mut self, data: &[u8]); } #[derive(Debug, PartialEq, Eq)] @@ -123,7 +123,7 @@ impl PacketData { } } -pub fn full_parse(packet: &Vec) -> (usize, PacketData) { +pub fn full_parse(packet: &[u8]) -> (usize, PacketData) { let mut i = 0; let mut last_layer = PacketData::UNDEF(EmptyLayer::new()); let mut next_type = LayerType::ETH; diff --git a/kernel/src/network/netsync.rs b/kernel/src/network/netsync.rs index ecc7d41..6581322 100644 --- a/kernel/src/network/netsync.rs +++ b/kernel/src/network/netsync.rs @@ -1,8 +1,6 @@ use spin::MutexGuard; -use crate::println; - -use super::rtl8139::{RTL8139, disable_network_interrupts, enable_network_interrupts}; +use super::rtl8139::RTL8139; pub struct NetworkInterruptsGuard<'a> { data: MutexGuard<'a, Option> @@ -50,13 +48,13 @@ pub struct InterruptCounter { } impl InterruptCounter { pub fn get(&self) -> u32 { - return self.data; + self.data } - pub fn inc(&mut self) -> () { + pub fn inc(&mut self) { self.data += 1; } - pub fn dec(&mut self) -> () { + pub fn dec(&mut self) { self.data -= 1; } } diff --git a/kernel/src/network/raw_array.rs b/kernel/src/network/raw_array.rs index f00cec3..6808dfb 100644 --- a/kernel/src/network/raw_array.rs +++ b/kernel/src/network/raw_array.rs @@ -60,7 +60,7 @@ impl WrappingRawArray { } // Ignore values - pub fn shift_amount(&mut self, amount: usize) -> () { + pub fn shift_amount(&mut self, amount: usize) { self.pos = (self.pos + amount) % self.size; } @@ -91,7 +91,7 @@ impl WrappingRawArray { impl Index for WrappingRawArray { type Output = u8; /// Index into the infinite array using raw pointers - fn index<'a>(&'a self, i: usize) -> &u8 { + fn index(&self, i: usize) -> &u8 { unsafe { &(*self.start.add(i % self.size)) } } } diff --git a/kernel/src/network/rtl8139.rs b/kernel/src/network/rtl8139.rs index 3104c5a..b8daed1 100644 --- a/kernel/src/network/rtl8139.rs +++ b/kernel/src/network/rtl8139.rs @@ -53,7 +53,7 @@ lazy_static! { static mut INTERRUPTS_ARE_ENABLED: spin::Mutex = spin::Mutex::new(InterruptCounter { data: 0 }); // Disable network interrupts (is thread safe) -pub fn disable_network_interrupts() -> () { +pub fn disable_network_interrupts() { let mut data = unsafe { INTERRUPTS_ARE_ENABLED.lock() }; if data.get() == 0 { let mut port_imr = Port::::new((unsafe { IO_BASE } + 0x3C) as u16); @@ -63,7 +63,7 @@ pub fn disable_network_interrupts() -> () { } // Enable network interrupts (is thread safe) -pub fn enable_network_interrupts() -> () { +pub fn enable_network_interrupts() { let mut data = unsafe { INTERRUPTS_ARE_ENABLED.lock() }; data.dec(); if data.get() == 0 { @@ -72,8 +72,7 @@ pub fn enable_network_interrupts() -> () { } } -pub extern "x86-interrupt" fn network_handle(_stack_frame: InterruptStackFrame) -> () { - println!("[INTERRUPT] - "); +pub extern "x86-interrupt" fn network_handle(_stack_frame: InterruptStackFrame) { // Try to get the device info let mut net_dev = NET_INFO.lock_no_disable(); if net_dev.is_none() { @@ -108,7 +107,7 @@ pub extern "x86-interrupt" fn network_handle(_stack_frame: InterruptStackFrame) if status & ROK != 0x0 { // println!("Receiving packet"); // Received packet - let pkt = recv_packet(&rtl_dev_info); + let pkt = recv_packet(rtl_dev_info); match pkt { PacketData::ARP(arp) => { // todo: also check for broadcast @@ -179,7 +178,8 @@ fn recv_packet(rtl_dev_info: &RTL8139) -> PacketData { // Make sure buffer isn't empty let cmd_reg = (rtl_dev_info.config.io_base.unwrap() + CR) as u16; let mut cmd_port = Port::::new(cmd_reg); - while unsafe { cmd_port.read() } & CR_BUFE == 0x0 { + // while unsafe { cmd_port.read() } & CR_BUFE == 0x0 { + if unsafe { cmd_port.read() } & CR_BUFE == 0x0 { // Receive a packet by reading the buffer // ? Reading the buffer is naturally unsafe? Is there a better way? let virtual_buffer_recv: VirtAddr = VirtAddr::new( @@ -200,7 +200,7 @@ fn recv_packet(rtl_dev_info: &RTL8139) -> PacketData { let packet = rx_buffer.trim((length - 4) as usize); // ? throw out the crc... we don't need to check it... rx_buffer.shift_amount(4); - let amount_parsed_and_pkt = full_parse(&packet); + let amount_parsed_and_pkt = full_parse(packet.as_slice()); // the amount we parse will be equal to length unless we are under the minimum assert!(amount_parsed_and_pkt.0 == (length - 4) as usize || length >= 64); @@ -217,10 +217,10 @@ fn recv_packet(rtl_dev_info: &RTL8139) -> PacketData { return amount_parsed_and_pkt.1; } else { unsafe { RECV_POS = (RECV_POS + 2) % RX_BUFFER_SIZE; } - break; + // break; } } - return PacketData::UNDEF(EmptyLayer::new()); + PacketData::UNDEF(EmptyLayer::new()) } // TODO: Split the driver into separate bits so we can lock individual resources? @@ -278,7 +278,7 @@ impl RTL8139 { true => println!("[INFO] Set up RTL8139 successful!!"), false => println!("[ERR] Set up RTL8139 failed!!") }; - return setup_status; + setup_status } fn new(configs: Vec) -> Option { @@ -291,12 +291,12 @@ impl RTL8139 { use_dev = Some(dev.clone()); } } - if use_dev.is_some() { + if let Some(device) = use_dev { // set the io base for enabling and disabling interrupts - unsafe { IO_BASE = use_dev.as_ref().unwrap().io_base.unwrap() as usize }; + unsafe { IO_BASE = device.io_base.unwrap() as usize }; // Return the device and a 12KB physical region return Some(RTL8139 { - config: use_dev.unwrap(), + config: device, recv_buffer: None, send_buffer: None, my_ip_address: None, @@ -351,13 +351,13 @@ impl RTL8139 { .register_irq(self.config.irq.unwrap() as usize, network_handle); // Get MAC address - let mac_addr = self.config.io_base.unwrap() + 0x00; + let mac_addr = self.config.io_base.unwrap(); // + 0 offset let mut mac1 = Port::::new(mac_addr as u16); let mut mac2 = Port::::new((mac_addr + 0x04) as u16); let mut mac_addr = 0x0_u64; unsafe { - mac_addr = mac_addr | mac1.read() as u64; - mac_addr = mac_addr | ((mac2.read() as u64) << 32); + mac_addr |= mac1.read() as u64; + mac_addr |= (mac2.read() as u64) << 32; }; // save mac address self.mac_address = Some(mac_addr); @@ -406,7 +406,7 @@ impl RTL8139 { // Enable interrupts // How I understood why: https://forum.osdev.org/viewtopic.php?f=1&t=27901 InterruptHandler::unblock_irq(self.config.irq.unwrap()); - return true; + true } // todo: remove this from driver code (add to socket.rs) @@ -421,10 +421,10 @@ impl RTL8139 { // wait for response // recursively try again - return 0; + 0 } - pub fn send_packet(&self, packet_data: &Vec) -> () { + pub fn send_packet(&self, packet_data: &Vec) { if self.send_buffer.is_none() || self.physical_mem_offset.is_none() { panic!("RTL8139 is not initialized properly"); } @@ -438,8 +438,8 @@ impl RTL8139 { VirtAddr::new(self.send_buffer.unwrap().as_u64() + self.physical_mem_offset.unwrap()); let virtual_buffer_ptr: *mut u8 = virtual_buffer.as_mut_ptr(); // Copy data into buffer - for i in 0..packet_data.len() { - unsafe { *(virtual_buffer_ptr.wrapping_add(i)) = packet_data[i] }; + for (i, item) in packet_data.iter().enumerate() { + unsafe { *(virtual_buffer_ptr.wrapping_add(i)) = *item }; } // TODO: Make this part of self... @@ -455,7 +455,7 @@ impl RTL8139 { // Send the packet from the buffer unsafe { TRANSMIT_IDX += 1; - TRANSMIT_IDX = TRANSMIT_IDX % 4; + TRANSMIT_IDX %= 4; }; } } diff --git a/kernel/src/network/socket.rs b/kernel/src/network/socket.rs index 326c63b..98ccd76 100644 --- a/kernel/src/network/socket.rs +++ b/kernel/src/network/socket.rs @@ -1,5 +1,3 @@ -use crate::println; - use super::{rtl8139::{NET_INFO, disable_network_interrupts, enable_network_interrupts}, layer::PacketData}; #[derive(Debug)] @@ -34,9 +32,9 @@ impl RawSocket { let rtl_dev_info = rtl_dev_info_guard.get_mut().unwrap(); match rtl_dev_info.ports.get_mut(&self.port) { Some(vec) => { - return vec.pop_front(); + vec.pop_front() }, - None => return None, + None => None, } } @@ -45,7 +43,7 @@ impl RawSocket { disable_network_interrupts(); let pkt = self.try_get_packet_inner(); enable_network_interrupts(); - return pkt; + pkt } // Query a port for a packet. Will block until a packet arrives @@ -53,32 +51,26 @@ impl RawSocket { let pkt; loop { x86_64::instructions::hlt(); - match self.try_get_packet() { - Some(next_pkt) => { - pkt = next_pkt; - break; - }, - None => {}, + if let Some(next_pkt) = self.try_get_packet() { + pkt = next_pkt; + break; } } - return pkt; + pkt } // Query a port for a packet. Will block until a packet arrives pub fn get_packet_with_timeout(&self, timeout_s: u32) -> Option { let mut pkt = None; for _ in 0..(18 * timeout_s) { - println!("One"); pkt = self.try_get_packet(); if pkt.is_some() { break; } - println!("Here?"); x86_64::instructions::hlt(); - println!("Two?"); } - return pkt; + pkt } - pub fn close(&mut self) -> () { + pub fn close(&mut self) { disable_network_interrupts(); let mut rtl_dev_info_guard = NET_INFO.lock(); let rtl_dev_info = rtl_dev_info_guard.get_mut().unwrap(); diff --git a/kernel/src/network/udp.rs b/kernel/src/network/udp.rs index a1db6bd..08f620e 100644 --- a/kernel/src/network/udp.rs +++ b/kernel/src/network/udp.rs @@ -69,7 +69,7 @@ impl Layer for UDPPacket { LayerType::UNDEF }, }; - return (packet, i, layer_type); + (packet, i, layer_type) } fn serialize(&self) -> alloc::vec::Vec { @@ -90,7 +90,7 @@ impl Layer for UDPPacket { } impl HasChecksum for UDPPacket { - fn calculate_checksum(&mut self, data: &[u8]) -> () { + fn calculate_checksum(&mut self, data: &[u8]) { // Starting vars let mut sum: u32 = 0; self.length.swap_endianness(); @@ -105,7 +105,7 @@ impl HasChecksum for UDPPacket { sum += (ip.destination_ip.data[2] as u32) | (ip.destination_ip.data[3] as u32) << 8; // Sum protocol and length - let protocol = Bytefield16::new(ip.protocol.as_byte() as u16); + let protocol = Bytefield16::new(ip.protocol as u16); sum += (protocol.data[0] as u32) | (protocol.data[1] as u32) << 8; sum += (self.length.data[0] as u32) | (self.length.data[1] as u32) << 8; @@ -128,11 +128,9 @@ impl HasChecksum for UDPPacket { sum = (sum & 0xFFFF) + (sum >> 16); } - // One's complement - let mut res = !sum as u16; - // Swap the bytes because we did our sum in big endian + // One's complement and swap the bytes because we did our sum in big endian // (and the bytefield will try to convert to big endian) - res = ((res >> 8) & 0xFF) | ((res & 0xFF) << 8); + let res = u16::swap_bytes(!sum as u16); // Return the one's complement of sum self.checksum = Bytefield16::new(res); diff --git a/src/main.rs b/src/main.rs index fdc19f5..e57aace 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,7 +6,7 @@ fn main() { let bios_path = env!("BIOS_PATH"); // choose whether to start the UEFI or BIOS image - let uefi = true; + let uefi = false; let mut cmd = std::process::Command::new("qemu-system-x86_64"); // ---- networking related arguments ---- // From a97c3c0c73011bead1af6dcbada4ca3cafb67f44 Mon Sep 17 00:00:00 2001 From: EthanLavi <89608076+EthanLavi@users.noreply.github.com> Date: Tue, 31 Oct 2023 18:38:46 -0400 Subject: [PATCH 03/36] Fixing more issues + formatting --- kernel/src/interrupts.rs | 30 ++- kernel/src/main.rs | 56 ++++- kernel/src/network/arp.rs | 6 +- kernel/src/network/arp_table.rs | 2 +- kernel/src/network/bytefield.rs | 315 ++++--------------------- kernel/src/network/command_register.rs | 14 +- kernel/src/network/constants.rs | 2 +- kernel/src/network/devices.rs | 29 ++- kernel/src/network/dhcp.rs | 52 ++-- kernel/src/network/init.rs | 29 ++- kernel/src/network/ip.rs | 31 ++- kernel/src/network/layer.rs | 60 +++-- kernel/src/network/mod.rs | 12 +- kernel/src/network/netsync.rs | 9 +- kernel/src/network/raw_array.rs | 7 +- kernel/src/network/rtl8139.rs | 72 ++++-- kernel/src/network/socket.rs | 17 +- kernel/src/network/udp.rs | 16 +- src/main.rs | 11 +- 19 files changed, 325 insertions(+), 445 deletions(-) diff --git a/kernel/src/interrupts.rs b/kernel/src/interrupts.rs index c65d9f1..6515736 100644 --- a/kernel/src/interrupts.rs +++ b/kernel/src/interrupts.rs @@ -1,17 +1,16 @@ -use lazy_static::lazy_static; -use x86_64::structures::idt::{InterruptDescriptorTable, InterruptStackFrame}; +use crate::hlt_loop; use crate::{gdt, println}; +use lazy_static::lazy_static; use pic8259::ChainedPics; use spin; use x86_64::structures::idt::PageFaultErrorCode; -use crate::hlt_loop; +use x86_64::structures::idt::{InterruptDescriptorTable, InterruptStackFrame}; pub const PIC_1_OFFSET: u8 = 32; pub const PIC_2_OFFSET: u8 = PIC_1_OFFSET + 8; -pub static PICS: spin::Mutex = spin::Mutex::new({ - unsafe { ChainedPics::new(PIC_1_OFFSET, PIC_2_OFFSET) } -}); +pub static PICS: spin::Mutex = + spin::Mutex::new(unsafe { ChainedPics::new(PIC_1_OFFSET, PIC_2_OFFSET) }); #[derive(Debug, Clone, Copy)] #[repr(u8)] @@ -31,17 +30,18 @@ impl InterruptIndex { } pub struct InterruptHandler { - idt: InterruptDescriptorTable + idt: InterruptDescriptorTable, } -pub type InterruptHandlerFunc = extern "x86-interrupt" fn (InterruptStackFrame) -> (); +pub type InterruptHandlerFunc = extern "x86-interrupt" fn(InterruptStackFrame) -> (); impl InterruptHandler { pub fn new() -> Self { let mut idt = InterruptDescriptorTable::new(); idt.breakpoint.set_handler_fn(breakpoint_handler); idt.page_fault.set_handler_fn(page_fault_handler); unsafe { - idt.double_fault.set_handler_fn(double_fault_handler) + idt.double_fault + .set_handler_fn(double_fault_handler) .set_stack_index(gdt::DOUBLE_FAULT_IST_INDEX); } idt[InterruptIndex::Timer.as_usize()].set_handler_fn(timer_interrupt_handler); @@ -49,7 +49,7 @@ impl InterruptHandler { InterruptHandler { idt } } - pub fn init(&self){ + pub fn init(&self) { unsafe { self.idt.load_unsafe() }; } @@ -59,7 +59,7 @@ impl InterruptHandler { let data = unsafe { locked_pics.read_masks() }; // set the irq bit to 0 if irq_num < 8 { - unsafe { locked_pics.write_masks(data[0] & !(1 << irq_num), data[1]) }; + unsafe { locked_pics.write_masks(data[0] & !(1 << irq_num), data[1]) }; } else { unsafe { locked_pics.write_masks(data[0], data[1] & !(1 << (irq_num - 8))) }; } @@ -71,13 +71,13 @@ impl InterruptHandler { let data = unsafe { locked_pics.read_masks() }; // set the irq bit to 1 if irq_num < 8 { - unsafe { locked_pics.write_masks(data[0] | 1 << irq_num, data[1]) }; + unsafe { locked_pics.write_masks(data[0] | 1 << irq_num, data[1]) }; } else { unsafe { locked_pics.write_masks(data[0], data[1] | 1 << (irq_num - 8)) }; } } - pub fn register_irq(&mut self, irq_num: usize, handler: InterruptHandlerFunc){ + pub fn register_irq(&mut self, irq_num: usize, handler: InterruptHandlerFunc) { println!("Registered Handler @ {}", irq_num + 32); self.idt[irq_num + 32].set_handler_fn(handler); unsafe { self.idt.load_unsafe() }; @@ -86,9 +86,7 @@ impl InterruptHandler { } lazy_static! { - pub static ref IDT: spin::Mutex = { - spin::Mutex::new(InterruptHandler::new()) - }; + pub static ref IDT: spin::Mutex = spin::Mutex::new(InterruptHandler::new()); } pub fn init_idt() { diff --git a/kernel/src/main.rs b/kernel/src/main.rs index 0de8cc6..7449126 100644 --- a/kernel/src/main.rs +++ b/kernel/src/main.rs @@ -12,8 +12,16 @@ use bootloader_api::{ use core::panic::PanicInfo; use kernel::{ framebuffer, hlt_loop, - network::{ethernet::{EthernetPacket, self}, udp::UDPPacket, ip::{IPPacket, Protocol}, layer::{Layer, HasChecksum, LayerType}, rtl8139::{disable_network_interrupts, NET_INFO, enable_network_interrupts}, socket::RawSocket, init::init_dhcp}, - println, print, + network::{ + ethernet::{self, EthernetPacket}, + init::init_dhcp, + ip::{IPPacket, Protocol}, + layer::{HasChecksum, Layer, LayerType}, + rtl8139::{disable_network_interrupts, enable_network_interrupts, NET_INFO}, + socket::RawSocket, + udp::UDPPacket, + }, + print, println, task::keyboard, task::{executor::Executor, Task}, }; @@ -73,7 +81,9 @@ fn kernel_main(boot_info: &'static mut BootInfo) -> ! { if rtl_driver.is_none() { panic!("Cannot find network card"); } - rtl_driver.unwrap().init(&mut frame_allocator, phys_mem_offset) + rtl_driver + .unwrap() + .init(&mut frame_allocator, phys_mem_offset) }; // so that the NET INFO gets released let status_init_dhcp = init_dhcp(2); if !status_init { @@ -86,7 +96,9 @@ fn kernel_main(boot_info: &'static mut BootInfo) -> ! { Ok(mut socket) => { loop { let pkt = socket.get_packet(); - if pkt.get_type() != LayerType::UDP { break; } + if pkt.get_type() != LayerType::UDP { + break; + } let udp_pkt = pkt.unwrap_udp(); let data_cloned = udp_pkt.data.clone(); let data_cloned_len = data_cloned.len(); @@ -99,24 +111,46 @@ fn kernel_main(boot_info: &'static mut BootInfo) -> ! { // send back a copy of the packet ("echo") let ip_layer_res = udp_pkt.ip_packet; let eth_layer_res = ip_layer_res.ethernet_packet; - let eth_layer = EthernetPacket::gen(eth_layer_res.src_mac.val(), eth_layer_res.dest_mac.val(), ethernet::EthType::IPv4); + let eth_layer = EthernetPacket::gen( + eth_layer_res.src_mac.val(), + eth_layer_res.dest_mac.val(), + ethernet::EthType::IPv4, + ); let udp_size = UDPPacket::packet_size() + data_cloned_len as u16; - let ip_layer = IPPacket::gen(eth_layer, udp_size, Protocol::UDP, ip_layer_res.destination_ip.val(), ip_layer_res.source_ip.val()); - let mut udp_layer = UDPPacket::gen(ip_layer, udp_pkt.dest_port.val(), udp_pkt.src_port.val(), data_cloned_len as u16); + let ip_layer = IPPacket::gen( + eth_layer, + udp_size, + Protocol::UDP, + ip_layer_res.destination_ip.val(), + ip_layer_res.source_ip.val(), + ); + let mut udp_layer = UDPPacket::gen( + ip_layer, + udp_pkt.dest_port.val(), + udp_pkt.src_port.val(), + data_cloned_len as u16, + ); udp_layer.data = data_cloned; let data_2_send = udp_layer.serialize(); - let start_udp = data_2_send.len() - (UDPPacket::packet_size() as usize + data_cloned_len); + let start_udp = + data_2_send.len() - (UDPPacket::packet_size() as usize + data_cloned_len); let start_ip = start_udp - (IPPacket::packet_size() as usize); - udp_layer.ip_packet.calculate_checksum(&data_2_send[start_ip..start_udp]); + udp_layer + .ip_packet + .calculate_checksum(&data_2_send[start_ip..start_udp]); udp_layer.calculate_checksum(&data_2_send[start_udp..]); let data_2_send_final = udp_layer.serialize(); disable_network_interrupts(); - NET_INFO.lock().get_ref().unwrap().send_packet(&data_2_send_final); + NET_INFO + .lock() + .get_ref() + .unwrap() + .send_packet(&data_2_send_final); enable_network_interrupts(); } println!("[INFO] Socket is closing"); socket.close(); - }, + } Err(err) => println!("{:?}", err), } } diff --git a/kernel/src/network/arp.rs b/kernel/src/network/arp.rs index e287087..58f4559 100644 --- a/kernel/src/network/arp.rs +++ b/kernel/src/network/arp.rs @@ -1,6 +1,6 @@ use super::{ bytefield::{Bytefield16, Bytefield32, Bytefield48, Bytefield8}, - ethernet::{EthernetPacket, EthType}, + ethernet::{EthType, EthernetPacket}, layer::{Layer, LayerType}, }; use alloc::vec; @@ -65,8 +65,8 @@ impl Layer for ArpPacket { packet.ethernet_packet = eth_layer; packet.hardware_type = Bytefield16::read_inc(&bytevec[i..], &mut i); packet.protocol_type = Bytefield16::read_inc(&bytevec[i..], &mut i); - packet.hardware_address_length = Bytefield8::read_inc(&bytevec[i..], &mut i).data; - packet.protocol_address_length = Bytefield8::read_inc(&bytevec[i..], &mut i).data; + packet.hardware_address_length = Bytefield8::read_inc(&bytevec[i..], &mut i).val(); + packet.protocol_address_length = Bytefield8::read_inc(&bytevec[i..], &mut i).val(); packet.operation = Bytefield16::read_inc(&bytevec[i..], &mut i); packet.sender_mac = Bytefield48::read_inc(&bytevec[i..], &mut i); packet.sender_ip = Bytefield32::read_inc(&bytevec[i..], &mut i); diff --git a/kernel/src/network/arp_table.rs b/kernel/src/network/arp_table.rs index e4cd1fd..02b0835 100644 --- a/kernel/src/network/arp_table.rs +++ b/kernel/src/network/arp_table.rs @@ -2,4 +2,4 @@ pub struct ArpEntry { pub mac: u64, pub ip: u32, pub expires: u16, -} \ No newline at end of file +} diff --git a/kernel/src/network/bytefield.rs b/kernel/src/network/bytefield.rs index 624d9f4..9ecb0dd 100644 --- a/kernel/src/network/bytefield.rs +++ b/kernel/src/network/bytefield.rs @@ -1,304 +1,77 @@ use core::ops::{Index, IndexMut}; -// N.B.: BytefieldS STORE IN BIG ENDIAN (as per network requirements) +// N.B.: Bytefields will swap the endianness of the values when created +// --> therefore serializing will create network byte order (big endian) Bytefields +// --> AND deserializing will create host byte order (small endian) Bytefields +// (That's why val() doesn't swap the byte order!) +// todo: refactor the api to track the state of the byte order? (would this work?) #[derive(Debug, Clone, Copy)] -pub struct Bytefield128 { - pub data: [u8; 16], +pub struct Bytefield { + pub data: [u8; N], } -impl Bytefield128 { - pub fn new(val: u128) -> Self { - let mut data = [0; Self::size()]; - for i in 0..Self::size() { - data[i] = (val >> ((Self::size() - 1 - i) * 8) & 0xFF) as u8; - } - Bytefield128 { data } - } - - pub fn read_inc(bytevec: &[u8], i: &mut usize) -> Self { - *i += Self::size(); - let mut data = [0; Self::size()]; - for i in 0..Self::size() { - data[i] = bytevec[Self::size() - 1 - i]; - } - Bytefield128 { data } - } - - // Get the value in swapped endianness (for example, if parsing a web thing, you'll get the little endian version) - pub fn val(&self) -> u128 { - let mut res = 0_u128; - for i in 0..Self::size() { - res |= (self.data[i] as u128) << (i * 8); - } - res - } - - // Swap the endianness of the data - pub fn swap_endianness(&mut self) { - self.data.reverse(); - } - - // Get the number of bytes - pub const fn size() -> usize { - 16 - } -} - -#[derive(Clone, Copy)] -pub struct Bytefield64 { - pub data: [u8; 8], -} - -impl Bytefield64 { - pub fn new(val: u64) -> Self { - let mut data = [0; Self::size()]; - for i in 0..Self::size() { - data[i] = (val >> ((Self::size() - 1 - i) * 8) & 0xFF) as u8; - } - Self { data } - } - - pub fn read_inc(bytevec: &[u8], i: &mut usize) -> Self { - *i += Self::size(); - let mut data = [0; Self::size()]; - for i in 0..Self::size() { - data[i] = bytevec[Self::size() - 1 - i]; - } - Self { data } - } - - // Get the value in swapped endianness (for example, if parsing a web thing, you'll get the little endian version) - pub fn val(&self) -> u64 { - let mut res = 0_u64; - for i in 0..Self::size() { - res |= (self.data[i] as u64) << (i * 8); - } - res - } - - // Swap the endianness of the data - pub fn swap_endianness(&mut self) { - self.data.reverse(); - } - - // Get the number of bytes - pub const fn size() -> usize { - 8 - } -} - -#[derive(Debug, Clone, Copy)] -pub struct Bytefield48 { - pub data: [u8; 6], -} - -impl Bytefield48 { - pub fn new(val: u64) -> Self { - let mut data = [0; Self::size()]; - for i in 0..Self::size() { - data[i] = (val >> ((Self::size() - 1 - i) * 8) & 0xFF) as u8; - } - Self { data } - } - - pub fn read_inc(bytevec: &[u8], i: &mut usize) -> Self { - *i += Self::size(); - let mut data = [0; Self::size()]; - for i in 0..Self::size() { - data[i] = bytevec[Self::size() - 1 - i]; - } - Self { data } - } - - // Get the value in swapped endianness - pub fn val(&self) -> u64 { - let mut res = 0_u64; - for i in 0..Self::size() { - res |= (self.data[i] as u64) << (i * 8); - } - res - } - - // Swap the endianness of the data - pub fn swap_endianness(&mut self) { - self.data.reverse(); - } - - // Get the number of bytes - pub const fn size() -> usize { - 6 - } -} - -#[derive(Debug, Clone, Copy)] -pub struct Bytefield32 { - pub data: [u8; 4], -} - -impl Bytefield32 { - pub fn new(val: u32) -> Self { - let mut data = [0; Self::size()]; - for i in 0..Self::size() { - data[i] = (val >> ((Self::size() - 1 - i) * 8) & 0xFF) as u8; - } +impl Bytefield { + pub fn swapped_endianness(self) -> Self { + let mut data = self.data; + data.reverse(); Self { data } } pub fn read_inc(bytevec: &[u8], i: &mut usize) -> Self { - *i += Self::size(); - let mut data = [0; Self::size()]; - for i in 0..Self::size() { - data[i] = bytevec[Self::size() - 1 - i]; + let mut data = [0_u8; N]; + for i in 0..N { + data[i] = bytevec[N - 1 - i]; } + *i += N; Self { data } } - // Get the value in swapped endianness (for example, if parsing a web thing, you'll get the little endian version) - pub fn val(&self) -> u32 { - let mut res = 0_u32; - for i in 0..Self::size() { - res |= (self.data[i] as u32) << (i * 8); - } - res - } - - // Swap the endianness of the data - pub fn swap_endianness(&mut self) { - self.data.reverse(); - } - - // Get the number of bytes pub const fn size() -> usize { - 4 + N } } -#[derive(Debug, Clone, Copy)] -pub struct Bytefield16 { - pub data: [u8; 2], -} - -impl Bytefield16 { - // Create a bytefield and swap endian-ness - pub fn new(val: u16) -> Self { - Self { data: [(val >> 8 & 0xFF) as u8, (val & 0xFF) as u8] } - } - - pub fn read_inc(bytevec: &[u8], i: &mut usize) -> Self { - *i += Self::size(); - Self { data: [bytevec[1], bytevec[0]] } - } - - /// Get the original value used to create the bytefield (will preserve the endian-ness or the parsed data) - pub fn val(&self) -> u16 { - let mut res = 0_u16; - for i in 0..Self::size() { - res |= (self.data[i] as u16) << (i * 8); - } - res - } - - // Swap the endianness of the data - pub fn swap_endianness(&mut self) { - self.data.reverse(); - } - - pub fn size() -> usize { - 2 - } -} - -#[derive(Clone, Copy)] -pub struct Bytefield8 { - pub data: u8, -} - -impl Bytefield8 { - pub fn new(data: u8) -> Self { - Bytefield8 { data } - } - - pub fn read(bytevec: &[u8]) -> Self { - Bytefield8 { data: bytevec[0] } - } - - pub fn read_inc(bytevec: &[u8], i: &mut usize) -> Self { - *i += Self::size(); - Self::read(bytevec) - } - - pub fn val(&self) -> u8 { - self.data - } - - pub fn size() -> usize { - 1 - } -} - -// ===== DEFINING INDEXING OPERATIONS ====== // -impl Index for Bytefield128 { +impl Index for Bytefield { type Output = u8; fn index(&self, i: usize) -> &u8 { &self.data[i] } } -impl IndexMut for Bytefield128 { +impl IndexMut for Bytefield { fn index_mut(&mut self, i: usize) -> &mut u8 { &mut self.data[i] } } -impl Index for Bytefield64 { - type Output = u8; - fn index(&self, i: usize) -> &u8 { - &self.data[i] - } -} +macro_rules! bytefield_int { + ($t:ident, $int:ident, $size:literal) => { + pub type $t = Bytefield<$size>; -impl IndexMut for Bytefield64 { - fn index_mut(&mut self, i: usize) -> &mut u8 { - &mut self.data[i] - } -} - -impl Index for Bytefield48 { - type Output = u8; - fn index(&self, i: usize) -> &u8 { - &self.data[i] - } -} + impl $t { + pub fn new(val: $int) -> Self { + let mut data = [0; $size]; + for i in 0..$size { + data[i] = (val >> (($size - 1 - i) * 8) & 0xFF) as u8; + } + $t { data } + } -impl IndexMut for Bytefield48 { - fn index_mut(&mut self, i: usize) -> &mut u8 { - &mut self.data[i] - } -} - -impl Index for Bytefield32 { - type Output = u8; - fn index(&self, i: usize) -> &u8 { - &self.data[i] - } -} - -impl IndexMut for Bytefield32 { - fn index_mut(&mut self, i: usize) -> &mut u8 { - &mut self.data[i] - } -} - -impl Index for Bytefield16 { - type Output = u8; - fn index(&self, i: usize) -> &u8 { - &self.data[i] - } + pub fn val(&self) -> $int { + let mut res = 0; + for i in 0..$size { + res |= (self.data[i] as $int) << (i * 8); + } + return res; + } + } + }; } -impl IndexMut for Bytefield16 { - fn index_mut(&mut self, i: usize) -> &mut u8 { - &mut self.data[i] - } -} \ No newline at end of file +bytefield_int!(Bytefield8, u8, 1); +bytefield_int!(Bytefield16, u16, 2); +bytefield_int!(Bytefield32, u32, 4); +bytefield_int!(Bytefield48, u64, 6); +bytefield_int!(Bytefield64, u64, 8); +bytefield_int!(Bytefield128, u128, 16); diff --git a/kernel/src/network/command_register.rs b/kernel/src/network/command_register.rs index 5e29067..669ab81 100644 --- a/kernel/src/network/command_register.rs +++ b/kernel/src/network/command_register.rs @@ -17,7 +17,7 @@ impl CommandRegister { pub fn get_io_space_bit(&self) -> bool { (self.cr & 0x1) != 0 } - pub fn set_io_space_bit(&mut self, is_on: bool){ + pub fn set_io_space_bit(&mut self, is_on: bool) { match is_on { true => self.cr |= 0x1, false => self.cr &= !0x1, @@ -28,7 +28,7 @@ impl CommandRegister { pub fn get_memory_space_bit(&self) -> bool { (self.cr & 0x2) != 0 } - pub fn set_memory_space_bit(&mut self, is_on: bool){ + pub fn set_memory_space_bit(&mut self, is_on: bool) { match is_on { true => self.cr |= 0x2, false => self.cr &= !0x2, @@ -39,7 +39,7 @@ impl CommandRegister { pub fn get_bus_master_bit(&self) -> bool { (self.cr & 0x4) != 0 } - pub fn set_bus_master_bit(&mut self, is_on: bool){ + pub fn set_bus_master_bit(&mut self, is_on: bool) { match is_on { true => self.cr |= 0x4, false => self.cr &= !0x4, @@ -65,7 +65,7 @@ impl CommandRegister { pub fn get_parity_err_res_bit(&self) -> bool { (self.cr & 0x40) != 0 } - pub fn set_parity_err_res_bit(&mut self, is_on: bool){ + pub fn set_parity_err_res_bit(&mut self, is_on: bool) { match is_on { true => self.cr |= 0x40, false => self.cr &= !0x40, @@ -76,7 +76,7 @@ impl CommandRegister { pub fn get_serr_enable_bit(&self) -> bool { (self.cr & 0x100) != 0 } - pub fn set_serr_enable_bit(&mut self, is_on: bool){ + pub fn set_serr_enable_bit(&mut self, is_on: bool) { match is_on { true => self.cr |= 0x100, false => self.cr &= !0x100, @@ -92,10 +92,10 @@ impl CommandRegister { pub fn get_interrupt_disable_bit(&self) -> bool { (self.cr & 0x400) != 0 } - pub fn set_interrupt_disable_bit(&mut self, is_on: bool){ + pub fn set_interrupt_disable_bit(&mut self, is_on: bool) { match is_on { true => self.cr |= 0x400, false => self.cr &= !0x400, } } -} \ No newline at end of file +} diff --git a/kernel/src/network/constants.rs b/kernel/src/network/constants.rs index 6148ca4..9d8addc 100644 --- a/kernel/src/network/constants.rs +++ b/kernel/src/network/constants.rs @@ -21,4 +21,4 @@ pub const CR_TE: u8 = 0x04; // Transmitter Enable, enables transmitting pub const CR_BUFE: u8 = 0x01; // Rx buffer is empty pub const CR: u32 = 0x37; // command register (1byte) pub const CAPR: u32 = 0x38; // current address of packet read (2byte, C mode, initial value 0xFFF0) -pub const RX_READ_PTR_MASK: u16 = !0x3; // \ No newline at end of file +pub const RX_READ_PTR_MASK: u16 = !0x3; // diff --git a/kernel/src/network/devices.rs b/kernel/src/network/devices.rs index 9dac963..b0fc5dd 100644 --- a/kernel/src/network/devices.rs +++ b/kernel/src/network/devices.rs @@ -42,7 +42,7 @@ pub enum PCIClassCodes { NonEssentialInstrumentation, CoProcessor, Reserved, - Unassigned + Unassigned, } impl PCIClassCodes { @@ -70,18 +70,19 @@ impl PCIClassCodes { 0x13 => Self::NonEssentialInstrumentation, 0x40 => Self::CoProcessor, 0xFF => Self::Unassigned, - _ => Self::Reserved + _ => Self::Reserved, } } } // Write into the config address -fn create_confg_address(bus: u8, slot: u8, func: u8, offset: u8){ +fn create_confg_address(bus: u8, slot: u8, func: u8, offset: u8) { let lbus = bus as u32; let lslot = slot as u32; let lfunc = func as u32; - let address = (lbus << 16) | (lslot << 11) | (lfunc << 8) | ((offset as u32) & 0xFC) | 0x80000000_u32; + let address = + (lbus << 16) | (lslot << 11) | (lfunc << 8) | ((offset as u32) & 0xFC) | 0x80000000_u32; let mut port = Port::::new(CONFIG_ADDRESS); // Write the address unsafe { port.write(address) }; @@ -107,7 +108,9 @@ fn pci_config_write_word(bus: u8, slot: u8, func: u8, offset: u8, word: u16) { // Read the data let data: u32 = unsafe { port.read() }; let new_data = (data & 0xFFFF0000) | word as u32; - unsafe { port.write(new_data); } + unsafe { + port.write(new_data); + } } /// Check if a device exists at (bus, slot) @@ -144,7 +147,7 @@ fn pci_get_class_code(bus: u8, slot: u8) -> (PCIClassCodes, u8) { // Read the interrupt line of the PCI Configuration address space // "For the x86 architecture this register corresponds to the PIC IRQ numbers 0-15" and a value of 0xFF defines no connection fn pci_get_irq(bus: u8, slot: u8) -> Option { - let irq = (pci_config_read_dword(bus, slot, 0, 0x3C) & 0xFF) as u8; + let irq = (pci_config_read_dword(bus, slot, 0, 0x3C) & 0xFF) as u8; if irq == 0xFF { None } else { @@ -188,7 +191,7 @@ impl Device { pub fn read_command_register(&self) -> CommandRegister { pci_get_cmd_reg(self.bus, self.slot) } - + // Write to the command register of the device pub fn write_command_register(&self, new_val: CommandRegister) { pci_set_cmd_reg(self.bus, self.slot, new_val); @@ -202,12 +205,14 @@ pub fn scan_devices() -> Vec { for bus in 0..255 { for slot in 0..31 { match pci_check_vendor(bus, slot) { - Some(_) => { device_bus_slots.push((bus, slot)); } - None => continue + Some(_) => { + device_bus_slots.push((bus, slot)); + } + None => continue, } } } - + let mut results: Vec = Vec::new(); for bus_slot in device_bus_slots.iter() { let bus = bus_slot.0; @@ -225,9 +230,9 @@ pub fn scan_devices() -> Vec { class_code: class_subclass.0, sub_class: class_subclass.1, irq, - io_base + io_base, }); } results -} \ No newline at end of file +} diff --git a/kernel/src/network/dhcp.rs b/kernel/src/network/dhcp.rs index 129f762..6db687f 100644 --- a/kernel/src/network/dhcp.rs +++ b/kernel/src/network/dhcp.rs @@ -1,9 +1,10 @@ use alloc::vec; use super::{ - bytefield::{Bytefield128, Bytefield16, Bytefield32, Bytefield8, Bytefield48}, - layer::{Layer, HasChecksum, LayerType}, - udp::UDPPacket, ip::IPPacket, + bytefield::{Bytefield128, Bytefield16, Bytefield32, Bytefield48, Bytefield8}, + ip::IPPacket, + layer::{HasChecksum, Layer, LayerType}, + udp::UDPPacket, }; struct WrappedU32 { @@ -77,27 +78,30 @@ impl DHCPPacket { } let mut dhcp = DHCPPacket { udp_packet, - op_code: 1, // 1 for is request - hardware_type: 1, // ethernet is 1 - hardware_address_length: 6, // corresponds with hardware address length - hops: 0, // must be 0 from client + op_code: 1, // 1 for is request + hardware_type: 1, // ethernet is 1 + hardware_address_length: 6, // corresponds with hardware address length + hops: 0, // must be 0 from client transaction_identifier: identification, // random ID - seconds: Bytefield16::new(0), // doesn't matter + seconds: Bytefield16::new(0), // doesn't matter flags: Bytefield16::new(if ip_address.is_none() { 0x1 } else { 0x0 }), // 1 indicates that I don't know my IP address - client_ip: Bytefield32::new(ip_address.unwrap_or(0)), // 0 since - my_ip: Bytefield32::new(0), // 4 bytes - server_ip: Bytefield32::new(0), // 4 bytes - gateway_ip: Bytefield32::new(0), // 4 bytes - client_hardware_address, // 16 bytes - sname: [0; 64], // 64 bytes - file: [0; 128], // 128 bytes - options: [0; 64], // todo: 64 bytes (can be more) - // 300 bytes total + client_ip: Bytefield32::new(ip_address.unwrap_or(0)), // 0 since + my_ip: Bytefield32::new(0), // 4 bytes + server_ip: Bytefield32::new(0), // 4 bytes + gateway_ip: Bytefield32::new(0), // 4 bytes + client_hardware_address, // 16 bytes + sname: [0; 64], // 64 bytes + file: [0; 128], // 128 bytes + options: [0; 64], // todo: 64 bytes (can be more) + // 300 bytes total }; let data = dhcp.serialize(); - let start_udp = data.len() - (DHCPPacket::packet_size() as usize + UDPPacket::packet_size() as usize); + let start_udp = + data.len() - (DHCPPacket::packet_size() as usize + UDPPacket::packet_size() as usize); let start_ip = start_udp - (IPPacket::packet_size() as usize); - dhcp.udp_packet.ip_packet.calculate_checksum(&data[start_ip..start_udp]); + dhcp.udp_packet + .ip_packet + .calculate_checksum(&data[start_ip..start_udp]); dhcp.udp_packet.calculate_checksum(&data[start_udp..]); dhcp } @@ -112,10 +116,10 @@ impl Layer for DHCPPacket { let mut packet = DHCPPacket::new(); // create an empty packet let mut i = 0; packet.udp_packet = udp_layer; - packet.op_code = Bytefield8::read_inc(&bytevec[i..], &mut i).data; - packet.hardware_type = Bytefield8::read_inc(&bytevec[i..], &mut i).data; - packet.hardware_address_length = Bytefield8::read_inc(&bytevec[i..], &mut i).data; - packet.hops = Bytefield8::read_inc(&bytevec[i..], &mut i).data; + packet.op_code = Bytefield8::read_inc(&bytevec[i..], &mut i).val(); + packet.hardware_type = Bytefield8::read_inc(&bytevec[i..], &mut i).val(); + packet.hardware_address_length = Bytefield8::read_inc(&bytevec[i..], &mut i).val(); + packet.hops = Bytefield8::read_inc(&bytevec[i..], &mut i).val(); packet.transaction_identifier = Bytefield32::read_inc(&bytevec[i..], &mut i); packet.seconds = Bytefield16::read_inc(&bytevec[i..], &mut i); packet.flags = Bytefield16::read_inc(&bytevec[i..], &mut i); @@ -159,4 +163,4 @@ impl Layer for DHCPPacket { fn packet_size() -> u16 { 300 } -} \ No newline at end of file +} diff --git a/kernel/src/network/init.rs b/kernel/src/network/init.rs index a4d7a6b..17931cc 100644 --- a/kernel/src/network/init.rs +++ b/kernel/src/network/init.rs @@ -1,14 +1,14 @@ +use super::constants::{BROADCAST_ADDR, BROADCAST_MAC, DHCP_CLIENT_PORT}; use crate::network::dhcp::DHCPPacket; -use crate::network::ethernet::{EthernetPacket, EthType}; +use crate::network::ethernet::{EthType, EthernetPacket}; use crate::network::ip::{IPPacket, Protocol}; -use crate::network::rtl8139::{NET_INFO, disable_network_interrupts, enable_network_interrupts}; +use crate::network::layer::{Layer, LayerType}; +use crate::network::rtl8139::{disable_network_interrupts, enable_network_interrupts, NET_INFO}; +use crate::network::socket::RawSocket; use crate::network::udp::UDPPacket; use crate::{network::constants::DHCP_SERVER_PORT, println}; -use crate::network::layer::{LayerType, Layer}; -use crate::network::socket::RawSocket; -use super::constants::{BROADCAST_ADDR, DHCP_CLIENT_PORT, BROADCAST_MAC}; -pub fn init(){ +pub fn init() { // todo bundle the init phases } @@ -21,10 +21,19 @@ pub fn init_dhcp(wait_timeout: u8) -> bool { let rtl_dev_info = rtl_dev_guard.get_ref().unwrap(); // send dhcp initial request - let eth = EthernetPacket::gen(BROADCAST_MAC, rtl_dev_info.mac_address.unwrap(), EthType::IPv4); + let eth = EthernetPacket::gen( + BROADCAST_MAC, + rtl_dev_info.mac_address.unwrap(), + EthType::IPv4, + ); let ip_size = DHCPPacket::packet_size() + UDPPacket::packet_size(); let ip = IPPacket::gen(eth, ip_size, Protocol::UDP, 0x0, BROADCAST_ADDR); - let udp = UDPPacket::gen(ip, DHCP_CLIENT_PORT, DHCP_SERVER_PORT, DHCPPacket::packet_size()); + let udp = UDPPacket::gen( + ip, + DHCP_CLIENT_PORT, + DHCP_SERVER_PORT, + DHCPPacket::packet_size(), + ); let dhcp = DHCPPacket::gen(udp, None, rtl_dev_info.mac_address.unwrap()); let packet_data = dhcp.serialize(); @@ -43,9 +52,9 @@ pub fn init_dhcp(wait_timeout: u8) -> bool { break; } timeout += 1; - if timeout == wait_timeout { + if timeout == wait_timeout { socket.close(); - return false; + return false; } } diff --git a/kernel/src/network/ip.rs b/kernel/src/network/ip.rs index bbd3a24..ca0ba0d 100644 --- a/kernel/src/network/ip.rs +++ b/kernel/src/network/ip.rs @@ -1,9 +1,10 @@ -use alloc::vec::Vec; use alloc::vec; +use alloc::vec::Vec; use super::{ - bytefield::{Bytefield8, Bytefield16, Bytefield32}, - layer::{Layer, HasChecksum, LayerType}, ethernet::EthernetPacket, + bytefield::{Bytefield16, Bytefield32, Bytefield8}, + ethernet::EthernetPacket, + layer::{HasChecksum, Layer, LayerType}, }; #[derive(Debug, Clone, Copy)] @@ -51,7 +52,7 @@ pub struct IPPacket { flags_fragment_offset: Bytefield16, // 2 bytes ttl: u8, // 1 byte pub protocol: Protocol, // 1 byte (public for checksumming) - pub checksum: Bytefield16, // 2 bytes + pub checksum: Bytefield16, // 2 bytes pub source_ip: Bytefield32, // 4 bytes (public for checksumming) pub destination_ip: Bytefield32, // 4 bytes (public for checksumming) // 20 bytes in total @@ -74,7 +75,13 @@ impl IPPacket { } } - pub fn gen(ethernet_packet: EthernetPacket, data_length: u16, protocol: Protocol, src_ip: u32, dst_ip: u32) -> Self { + pub fn gen( + ethernet_packet: EthernetPacket, + data_length: u16, + protocol: Protocol, + src_ip: u32, + dst_ip: u32, + ) -> Self { let identification = unsafe { let mut id_gen = ID_GEN.lock(); let id_gen_old = id_gen.get(); @@ -104,17 +111,17 @@ impl Layer for IPPacket { Self: Sized, { let mut packet = IPPacket::new(); // create an empty packet - // Read 20 bytes + // Read 20 bytes let mut i = 0; packet.ethernet_packet = ethernet_layer; - packet.version_hlen = Bytefield8::read_inc(&bytevec[i..], &mut i).data; - packet.type_of_service = Bytefield8::read_inc(&bytevec[i..], &mut i).data; + packet.version_hlen = Bytefield8::read_inc(&bytevec[i..], &mut i).val(); + packet.type_of_service = Bytefield8::read_inc(&bytevec[i..], &mut i).val(); packet.total_length = Bytefield16::read_inc(&bytevec[i..], &mut i); packet.identification = Bytefield16::read_inc(&bytevec[i..], &mut i); packet.flags_fragment_offset = Bytefield16::read_inc(&bytevec[i..], &mut i); - packet.ttl = Bytefield8::read_inc(&bytevec[i..], &mut i).data; + packet.ttl = Bytefield8::read_inc(&bytevec[i..], &mut i).val(); let protocol = Bytefield8::read_inc(&bytevec[i..], &mut i); - packet.protocol = Protocol::from(protocol.data); + packet.protocol = Protocol::from(protocol.val()); packet.checksum = Bytefield16::read_inc(&bytevec[i..], &mut i); packet.source_ip = Bytefield32::read_inc(&bytevec[i..], &mut i); packet.destination_ip = Bytefield32::read_inc(&bytevec[i..], &mut i); @@ -176,11 +183,11 @@ impl HasChecksum for IPPacket { sum = (sum & 0xFFFF) + (sum >> 16); } - // One's complement and swap the bytes because we did our sum in big endian + // One's complement and swap the bytes because we did our sum in big endian // (and the bytefield will try to convert to big endian) let res = u16::swap_bytes(!sum as u16); // Return the one's complement of sum self.checksum = Bytefield16::new(res); } -} \ No newline at end of file +} diff --git a/kernel/src/network/layer.rs b/kernel/src/network/layer.rs index 5ac71a8..b8828e6 100644 --- a/kernel/src/network/layer.rs +++ b/kernel/src/network/layer.rs @@ -11,7 +11,9 @@ pub trait Layer { type Input; /// Parse a bytevec (vector slice) and return the packet and the amount of data parsed - fn parse(upper: Self::Input, bytevec: &[u8]) -> (Self, usize, LayerType) where Self: Sized; + fn parse(upper: Self::Input, bytevec: &[u8]) -> (Self, usize, LayerType) + where + Self: Sized; /// Convert a packet into a vector of bytes fn serialize(&self) -> Vec; @@ -23,13 +25,16 @@ pub trait Layer { pub struct EmptyLayer {} impl EmptyLayer { pub fn new() -> Self { - EmptyLayer { } + EmptyLayer {} } } impl Layer for EmptyLayer { type Input = EmptyLayer; - fn parse(_upper: EmptyLayer, _bytevec: &[u8]) -> (Self, usize, LayerType) where Self: Sized { + fn parse(_upper: EmptyLayer, _bytevec: &[u8]) -> (Self, usize, LayerType) + where + Self: Sized, + { (Self {}, 0, LayerType::UNDEF) } @@ -49,14 +54,14 @@ pub trait HasChecksum { #[derive(Debug, PartialEq, Eq)] pub enum LayerType { - ETH, + ETH, IP, ARP, UDP, ICMP, DHCP, TCP, - UNDEF // the default layer type + UNDEF, // the default layer type } /// Wrapper type to allow me to return a generic @@ -69,7 +74,7 @@ pub enum PacketData { ICMP(EmptyLayer), DHCP(DHCPPacket), TCP(EmptyLayer), - UNDEF(EmptyLayer) + UNDEF(EmptyLayer), } impl PacketData { @@ -131,42 +136,53 @@ pub fn full_parse(packet: &[u8]) -> (usize, PacketData) { match next_type { LayerType::ETH => { let last_layer_data = last_layer.unwrap_undef(); - let (eth_layer, size, network_layer_type) = EthernetPacket::parse(last_layer_data, &packet[i..]); + let (eth_layer, size, network_layer_type) = + EthernetPacket::parse(last_layer_data, &packet[i..]); last_layer = PacketData::ETH(eth_layer); i += size; next_type = network_layer_type; - }, + } LayerType::IP => { let last_layer_data = last_layer.unwrap_eth(); - let (ip_layer, size, transport_layer_type) = IPPacket::parse(last_layer_data, &packet[i..]); + let (ip_layer, size, transport_layer_type) = + IPPacket::parse(last_layer_data, &packet[i..]); last_layer = PacketData::IP(ip_layer); i += size; next_type = transport_layer_type; - }, + } LayerType::ARP => { let last_layer_data = last_layer.unwrap_eth(); - let (arp_layer, size, transport_layer_type) = ArpPacket::parse(last_layer_data, &packet[i..]); + let (arp_layer, size, transport_layer_type) = + ArpPacket::parse(last_layer_data, &packet[i..]); last_layer = PacketData::ARP(arp_layer); i += size; next_type = transport_layer_type; - }, + } LayerType::UDP => { let last_layer_data = last_layer.unwrap_ip(); - let (udp_layer, size, application_layer_type) = UDPPacket::parse(last_layer_data, &packet[i..]); + let (udp_layer, size, application_layer_type) = + UDPPacket::parse(last_layer_data, &packet[i..]); last_layer = PacketData::UDP(udp_layer); i += size; next_type = application_layer_type; - }, - LayerType::ICMP => { return (0, PacketData::UNDEF(EmptyLayer::new())); }, - LayerType::DHCP =>{ + } + LayerType::ICMP => { + return (0, PacketData::UNDEF(EmptyLayer::new())); + } + LayerType::DHCP => { let last_layer_data = last_layer.unwrap_udp(); - let (dhcp_layer, size, empty_type) = DHCPPacket::parse(last_layer_data, &packet[i..]); + let (dhcp_layer, size, empty_type) = + DHCPPacket::parse(last_layer_data, &packet[i..]); last_layer = PacketData::DHCP(dhcp_layer); i += size; next_type = empty_type; - }, - LayerType::TCP => { return (0, PacketData::UNDEF(EmptyLayer::new())); }, - LayerType::UNDEF => { return (i, last_layer); }, + } + LayerType::TCP => { + return (0, PacketData::UNDEF(EmptyLayer::new())); + } + LayerType::UNDEF => { + return (i, last_layer); + } } - }; -} \ No newline at end of file + } +} diff --git a/kernel/src/network/mod.rs b/kernel/src/network/mod.rs index 4e6bd7b..0c3ae54 100644 --- a/kernel/src/network/mod.rs +++ b/kernel/src/network/mod.rs @@ -1,20 +1,20 @@ -pub mod bytefield; pub mod arp; +pub mod bytefield; pub mod command_register; pub mod devices; +pub mod dhcp; pub mod ethernet; +pub mod init; pub mod ip; pub mod layer; pub mod rtl8139; -pub mod udp; -pub mod dhcp; pub mod socket; -pub mod init; +pub mod udp; // todo: remove pub until things break... -mod raw_array; mod arp_table; -mod netsync; pub mod constants; +mod netsync; +mod raw_array; // pub mod e1000; /* diff --git a/kernel/src/network/netsync.rs b/kernel/src/network/netsync.rs index 6581322..115f803 100644 --- a/kernel/src/network/netsync.rs +++ b/kernel/src/network/netsync.rs @@ -3,7 +3,7 @@ use spin::MutexGuard; use super::rtl8139::RTL8139; pub struct NetworkInterruptsGuard<'a> { - data: MutexGuard<'a, Option> + data: MutexGuard<'a, Option>, } impl NetworkInterruptsGuard<'_> { @@ -26,7 +26,7 @@ impl Drop for NetworkInterruptsGuard<'_> { } pub struct SafeRTL8139 { - data: spin::Mutex> + data: spin::Mutex>, } impl SafeRTL8139 { @@ -35,14 +35,15 @@ impl SafeRTL8139 { } pub fn lock(&self) -> NetworkInterruptsGuard { // disable_network_interrupts(); - return NetworkInterruptsGuard { data: self.data.lock() } + return NetworkInterruptsGuard { + data: self.data.lock(), + }; } pub fn lock_no_disable(&self) -> MutexGuard> { return self.data.lock(); } } - pub struct InterruptCounter { pub data: u32, } diff --git a/kernel/src/network/raw_array.rs b/kernel/src/network/raw_array.rs index 6808dfb..be4f8f1 100644 --- a/kernel/src/network/raw_array.rs +++ b/kernel/src/network/raw_array.rs @@ -1,6 +1,6 @@ -use core::ops::Index; use alloc::vec; use alloc::vec::Vec; +use core::ops::Index; // Leaving this here unless we change an implementation that necessiates a differnt type of array /*pub struct RawArray { @@ -42,7 +42,6 @@ impl Index for RawArray { } }*/ - pub struct WrappingRawArray { start: *const u8, pos: usize, @@ -53,7 +52,7 @@ impl WrappingRawArray { /// An infinite array beginning at "start" and wrapping after size bytes pub fn new(start: *const u8, size: usize) -> Self { WrappingRawArray { - start, + start, pos: 0, size, } @@ -85,7 +84,6 @@ impl WrappingRawArray { } res } - } impl Index for WrappingRawArray { @@ -95,4 +93,3 @@ impl Index for WrappingRawArray { unsafe { &(*self.start.add(i % self.size)) } } } - diff --git a/kernel/src/network/rtl8139.rs b/kernel/src/network/rtl8139.rs index b8daed1..93d6d77 100644 --- a/kernel/src/network/rtl8139.rs +++ b/kernel/src/network/rtl8139.rs @@ -1,5 +1,5 @@ -use alloc::{vec, collections::VecDeque}; use alloc::vec::Vec; +use alloc::{collections::VecDeque, vec}; use hashbrown::{HashMap, HashSet}; use lazy_static::lazy_static; @@ -14,22 +14,32 @@ use x86_64::{ use crate::interrupts::IDT; -use crate::network::constants::{RX_BUFFER_SIZE, CR, CR_RE, CR_TE, CR_BUFE, RX_READ_PTR_MASK, CAPR}; +use crate::network::constants::{ + CAPR, CR, CR_BUFE, CR_RE, CR_TE, RX_BUFFER_SIZE, RX_READ_PTR_MASK, +}; use crate::network::raw_array::WrappingRawArray; -use crate::{ - interrupts::{InterruptHandler, PICS}, - memory::BootInfoFrameAllocator, - network::{devices, netsync::SafeRTL8139, ethernet::{EthernetPacket, EthType}, arp::ArpPacket, layer::Layer}, - println, +use super::constants::{ + BROADCAST_ADDR, INTERRUPT_MASK, ROK, RTL_DEV, RTL_VEND, TOK, TRANSMIT_CMD, TRANSMIT_REG, }; -use super::constants::{INTERRUPT_MASK, ROK, TOK, RTL_VEND, RTL_DEV, TRANSMIT_REG, TRANSMIT_CMD, BROADCAST_ADDR}; use super::{ arp_table::ArpEntry, devices::{Device, PCIClassCodes}, layer::{full_parse, EmptyLayer, PacketData}, netsync::InterruptCounter, }; +use crate::{ + interrupts::{InterruptHandler, PICS}, + memory::BootInfoFrameAllocator, + network::{ + arp::ArpPacket, + devices, + ethernet::{EthType, EthernetPacket}, + layer::Layer, + netsync::SafeRTL8139, + }, + println, +}; // ISR_ROK|ISR_TOK|ISR_RXOVW|ISR_TER|ISR_RER // ! URGENT: start checking other statuses for errors @@ -51,7 +61,8 @@ lazy_static! { }; } -static mut INTERRUPTS_ARE_ENABLED: spin::Mutex = spin::Mutex::new(InterruptCounter { data: 0 }); +static mut INTERRUPTS_ARE_ENABLED: spin::Mutex = + spin::Mutex::new(InterruptCounter { data: 0 }); // Disable network interrupts (is thread safe) pub fn disable_network_interrupts() { let mut data = unsafe { INTERRUPTS_ARE_ENABLED.lock() }; @@ -113,8 +124,17 @@ pub extern "x86-interrupt" fn network_handle(_stack_frame: InterruptStackFrame) // todo: also check for broadcast if arp.recp_ip.val() == rtl_dev_info.my_ip_address.unwrap_or(0) { // println!("[INT-HANDLER] Send a response back"); - let eth_layer = EthernetPacket::gen(arp.sender_mac.val(), rtl_dev_info.mac_address.unwrap(), EthType::Arp); - let arp_layer = ArpPacket::gen(eth_layer, rtl_dev_info.my_ip_address.unwrap_or(BROADCAST_ADDR), arp.sender_ip.val(), false); + let eth_layer = EthernetPacket::gen( + arp.sender_mac.val(), + rtl_dev_info.mac_address.unwrap(), + EthType::Arp, + ); + let arp_layer = ArpPacket::gen( + eth_layer, + rtl_dev_info.my_ip_address.unwrap_or(BROADCAST_ADDR), + arp.sender_ip.val(), + false, + ); let arp_pkt = arp_layer.serialize(); rtl_dev_info.send_packet(&arp_pkt); } else { @@ -156,7 +176,7 @@ pub extern "x86-interrupt" fn network_handle(_stack_frame: InterruptStackFrame) .unwrap() .push_back(PacketData::UDP(udp)); } - }, + } _ => {} // ignore others } } @@ -206,18 +226,26 @@ fn recv_packet(rtl_dev_info: &RTL8139) -> PacketData { assert!(amount_parsed_and_pkt.0 == (length - 4) as usize || length >= 64); // after receiving the packet, update CAPR and RECV_POS // increment recv_pos - unsafe { RECV_POS = (RECV_POS + 2) % RX_BUFFER_SIZE; } - unsafe { RECV_POS = (RECV_POS + length) % RX_BUFFER_SIZE; } + unsafe { + RECV_POS = (RECV_POS + 2) % RX_BUFFER_SIZE; + } + unsafe { + RECV_POS = (RECV_POS + length) % RX_BUFFER_SIZE; + } // we and with RX_READ_PTR_MASK to ensure double word alignment // 4 for header length, 3 is also apparently for dword alignment... - unsafe { RECV_POS = ((RECV_POS + 4) & RX_READ_PTR_MASK) % RX_BUFFER_SIZE; } + unsafe { + RECV_POS = ((RECV_POS + 4) & RX_READ_PTR_MASK) % RX_BUFFER_SIZE; + } let mut capr = Port::::new((rtl_dev_info.config.io_base.unwrap() + CAPR) as u16); // println!("[RECV_POS] {}", unsafe { RECV_POS }); unsafe { capr.write(RECV_POS - 0x10) }; return amount_parsed_and_pkt.1; } else { - unsafe { RECV_POS = (RECV_POS + 2) % RX_BUFFER_SIZE; } - // break; + unsafe { + RECV_POS = (RECV_POS + 2) % RX_BUFFER_SIZE; + } + // break; } } PacketData::UNDEF(EmptyLayer::new()) @@ -240,7 +268,11 @@ pub struct RTL8139 { impl RTL8139 { // Initialize the card - pub fn init(&mut self, frame_allocator: &mut BootInfoFrameAllocator, physical_mem_offset: u64) -> bool { + pub fn init( + &mut self, + frame_allocator: &mut BootInfoFrameAllocator, + physical_mem_offset: u64, + ) -> bool { let mut frames: Vec = vec![]; // create recv and send buffer space for _ in 0..6 { @@ -276,7 +308,7 @@ impl RTL8139 { let setup_status = self.setup(); match setup_status { true => println!("[INFO] Set up RTL8139 successful!!"), - false => println!("[ERR] Set up RTL8139 failed!!") + false => println!("[ERR] Set up RTL8139 failed!!"), }; setup_status } @@ -461,4 +493,4 @@ impl RTL8139 { } // sudo qemu-system-x86_64 -M q35 -serial mon:stdio -nographic -netdev vmnet-bridged,id=net0,ifname=en0 -device rtl8139,netdev=net0,mac=00:11:22:33:44:55 -// \ No newline at end of file +// diff --git a/kernel/src/network/socket.rs b/kernel/src/network/socket.rs index 98ccd76..7ad42c4 100644 --- a/kernel/src/network/socket.rs +++ b/kernel/src/network/socket.rs @@ -1,4 +1,7 @@ -use super::{rtl8139::{NET_INFO, disable_network_interrupts, enable_network_interrupts}, layer::PacketData}; +use super::{ + layer::PacketData, + rtl8139::{disable_network_interrupts, enable_network_interrupts, NET_INFO}, +}; #[derive(Debug)] pub enum NetworkErrors { @@ -8,7 +11,7 @@ pub enum NetworkErrors { // todo: Implement a socket for user-space... pub struct RawSocket { - port: u16 + port: u16, } impl RawSocket { @@ -31,9 +34,7 @@ impl RawSocket { let mut rtl_dev_info_guard = NET_INFO.lock(); let rtl_dev_info = rtl_dev_info_guard.get_mut().unwrap(); match rtl_dev_info.ports.get_mut(&self.port) { - Some(vec) => { - vec.pop_front() - }, + Some(vec) => vec.pop_front(), None => None, } } @@ -64,7 +65,9 @@ impl RawSocket { let mut pkt = None; for _ in 0..(18 * timeout_s) { pkt = self.try_get_packet(); - if pkt.is_some() { break; } + if pkt.is_some() { + break; + } x86_64::instructions::hlt(); } pkt @@ -83,4 +86,4 @@ impl RawSocket { } enable_network_interrupts(); } -} \ No newline at end of file +} diff --git a/kernel/src/network/udp.rs b/kernel/src/network/udp.rs index 08f620e..95f4ebb 100644 --- a/kernel/src/network/udp.rs +++ b/kernel/src/network/udp.rs @@ -9,13 +9,13 @@ use alloc::vec::Vec; #[derive(Debug)] pub struct UDPPacket { - pub ip_packet: IPPacket, // public for checksumming + pub ip_packet: IPPacket, // public for checksumming pub src_port: Bytefield16, // 2 bytes pub dest_port: Bytefield16, // 2 bytes pub length: Bytefield16, // 2 bytes - checksum: Bytefield16, // 2 bytes - pub data: Vec, // a vector for data bytes if needed - // 10 bytes total + checksum: Bytefield16, // 2 bytes + pub data: Vec, // a vector for data bytes if needed + // 10 bytes total } impl UDPPacket { @@ -67,7 +67,7 @@ impl Layer for UDPPacket { } assert!(i == packet.length.val() as usize); LayerType::UNDEF - }, + } }; (packet, i, layer_type) } @@ -93,9 +93,8 @@ impl HasChecksum for UDPPacket { fn calculate_checksum(&mut self, data: &[u8]) { // Starting vars let mut sum: u32 = 0; - self.length.swap_endianness(); - let mut udp_len = self.length.val() as usize; - self.length.swap_endianness(); // swap back because we don't want to permanently mutate the length when calculating... + // calculating checksum on serialized bytefield (so its network byte order and must be swapped) + let mut udp_len = self.length.swapped_endianness().val() as usize; // First we do the IP as a pseduo header let ip = &self.ip_packet; @@ -136,4 +135,3 @@ impl HasChecksum for UDPPacket { self.checksum = Bytefield16::new(res); } } - diff --git a/src/main.rs b/src/main.rs index e57aace..3cc8739 100644 --- a/src/main.rs +++ b/src/main.rs @@ -11,14 +11,17 @@ fn main() { let mut cmd = std::process::Command::new("qemu-system-x86_64"); // ---- networking related arguments ---- // // related to issue with kernel driver freezing. see https://wiki.osdev.org/RTL8139 - cmd.arg("-machine").arg("kernel_irqchip=off"); + cmd.arg("-machine").arg("kernel_irqchip=off"); // Virtualizing a network, has an entry point to inject packets into the isolated network // Can send UDP from localhost:5555 to SOUP:5554 - cmd.arg("-netdev").arg("user,id=net0,hostfwd=udp::5555-:5554"); + cmd.arg("-netdev") + .arg("user,id=net0,hostfwd=udp::5555-:5554"); // Making sure we have the rtl8139 as a hardware resource - cmd.arg("-device").arg("rtl8139,netdev=net0,mac=00:11:22:33:44:55"); + cmd.arg("-device") + .arg("rtl8139,netdev=net0,mac=00:11:22:33:44:55"); // Recording the network traffic in order to conduct debugging and analysis - cmd.arg("-object").arg("filter-dump,id=f1,netdev=net0,file=/tmp/dump.pcap"); + cmd.arg("-object") + .arg("filter-dump,id=f1,netdev=net0,file=/tmp/dump.pcap"); if uefi { cmd.arg("-bios").arg(ovmf_prebuilt::ovmf_pure_efi()); From ea55b67feaeb2ac072a50cc38244c1469a618d8a Mon Sep 17 00:00:00 2001 From: EthanLavi <89608076+EthanLavi@users.noreply.github.com> Date: Tue, 31 Oct 2023 18:42:08 -0400 Subject: [PATCH 04/36] Fixing more merge stuff --- kernel/src/network/bytefield.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kernel/src/network/bytefield.rs b/kernel/src/network/bytefield.rs index f2581e7..1b4cb4a 100644 --- a/kernel/src/network/bytefield.rs +++ b/kernel/src/network/bytefield.rs @@ -182,7 +182,7 @@ pub struct Bytefield16 { impl Bytefield16 { // Create a bytefield and swap endian-ness pub fn new(val: u16) -> Self { - Self { data: [(val >> 1 * 8 & 0xFF) as u8, (val >> 0 * 8 & 0xFF) as u8] } + Self { data: [(val >> 1 * 8 & 0xFF) as u8, (val & 0xFF) as u8] } } pub fn read_inc(bytevec: &[u8], i: &mut usize) -> Self { From d31589178a2b079e0bee8ccbb32b3a78077c5ed9 Mon Sep 17 00:00:00 2001 From: EthanLavi <89608076+EthanLavi@users.noreply.github.com> Date: Tue, 31 Oct 2023 18:43:40 -0400 Subject: [PATCH 05/36] Clippy fix --- kernel/src/interrupts.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kernel/src/interrupts.rs b/kernel/src/interrupts.rs index 6d82f3d..d986f3f 100644 --- a/kernel/src/interrupts.rs +++ b/kernel/src/interrupts.rs @@ -1,7 +1,7 @@ use crate::{gdt, hlt_loop}; use lazy_static::lazy_static; use x86_64::structures::idt::{InterruptDescriptorTable, InterruptStackFrame, PageFaultErrorCode}; -use crate::{println, print}; +use crate::println; use pic8259::ChainedPics; use spin; From 55a93e1b55ec25f3de221ac8003226d3548831c3 Mon Sep 17 00:00:00 2001 From: EthanLavi <89608076+EthanLavi@users.noreply.github.com> Date: Tue, 31 Oct 2023 18:48:37 -0400 Subject: [PATCH 06/36] Small dhcp bug fix --- kernel/src/main.rs | 2 +- kernel/src/network/init.rs | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/kernel/src/main.rs b/kernel/src/main.rs index 7449126..25da3c8 100644 --- a/kernel/src/main.rs +++ b/kernel/src/main.rs @@ -85,7 +85,7 @@ fn kernel_main(boot_info: &'static mut BootInfo) -> ! { .unwrap() .init(&mut frame_allocator, phys_mem_offset) }; // so that the NET INFO gets released - let status_init_dhcp = init_dhcp(2); + let status_init_dhcp = init_dhcp(10); if !status_init { println!("[ERR] Cannot init RTL8139"); } else if !status_init_dhcp { diff --git a/kernel/src/network/init.rs b/kernel/src/network/init.rs index 17931cc..0f74bf6 100644 --- a/kernel/src/network/init.rs +++ b/kernel/src/network/init.rs @@ -50,6 +50,12 @@ pub fn init_dhcp(wait_timeout: u8) -> bool { println!("Got good packet"); pkt_data = dhcp_res; break; + } else { + disable_network_interrupts(); + let rtl_dev_guard = NET_INFO.lock(); + let rtl_dev_info = rtl_dev_guard.get_ref().unwrap(); + rtl_dev_info.send_packet(&packet_data); // send another packet + enable_network_interrupts(); } timeout += 1; if timeout == wait_timeout { From 8a59a99b9a0bb59ffc76208350bf138e023e84cb Mon Sep 17 00:00:00 2001 From: EthanLavi <89608076+EthanLavi@users.noreply.github.com> Date: Tue, 7 Nov 2023 13:58:36 -0500 Subject: [PATCH 07/36] Async udp echo --- kernel/src/main.rs | 79 +++------------------------------ kernel/src/network/init.rs | 10 ++++- kernel/src/network/rtl8139.rs | 3 ++ kernel/src/network/socket.rs | 25 +++++++++++ kernel/src/task/mod.rs | 1 + kernel/src/task/udp_echo.rs | 82 +++++++++++++++++++++++++++++++++++ 6 files changed, 126 insertions(+), 74 deletions(-) create mode 100644 kernel/src/task/udp_echo.rs diff --git a/kernel/src/main.rs b/kernel/src/main.rs index 25da3c8..d4042e6 100644 --- a/kernel/src/main.rs +++ b/kernel/src/main.rs @@ -4,7 +4,6 @@ #![test_runner(kernel::test_runner)] #![reexport_test_harness_main = "test_main"] -use alloc::string::String; use bootloader_api::{ config::{BootloaderConfig, Mapping}, entry_point, BootInfo, @@ -13,16 +12,12 @@ use core::panic::PanicInfo; use kernel::{ framebuffer, hlt_loop, network::{ - ethernet::{self, EthernetPacket}, init::init_dhcp, - ip::{IPPacket, Protocol}, - layer::{HasChecksum, Layer, LayerType}, - rtl8139::{disable_network_interrupts, enable_network_interrupts, NET_INFO}, - socket::RawSocket, - udp::UDPPacket, + rtl8139::NET_INFO, }, - print, println, + println, task::keyboard, + task::udp_echo, task::{executor::Executor, Task}, }; @@ -88,79 +83,19 @@ fn kernel_main(boot_info: &'static mut BootInfo) -> ! { let status_init_dhcp = init_dhcp(10); if !status_init { println!("[ERR] Cannot init RTL8139"); + hlt_loop(); } else if !status_init_dhcp { println!("[ERR] DHCP error -- whats my ip?"); - } else { - let raw_socket = RawSocket::new(5554); - match raw_socket { - Ok(mut socket) => { - loop { - let pkt = socket.get_packet(); - if pkt.get_type() != LayerType::UDP { - break; - } - let udp_pkt = pkt.unwrap_udp(); - let data_cloned = udp_pkt.data.clone(); - let data_cloned_len = data_cloned.len(); - let user_message = String::from_utf8(udp_pkt.data); - match user_message { - Ok(message) => print!("[USER] {}", message), - Err(err) => println!("[USER-ERR] {:?}", err), - } - - // send back a copy of the packet ("echo") - let ip_layer_res = udp_pkt.ip_packet; - let eth_layer_res = ip_layer_res.ethernet_packet; - let eth_layer = EthernetPacket::gen( - eth_layer_res.src_mac.val(), - eth_layer_res.dest_mac.val(), - ethernet::EthType::IPv4, - ); - let udp_size = UDPPacket::packet_size() + data_cloned_len as u16; - let ip_layer = IPPacket::gen( - eth_layer, - udp_size, - Protocol::UDP, - ip_layer_res.destination_ip.val(), - ip_layer_res.source_ip.val(), - ); - let mut udp_layer = UDPPacket::gen( - ip_layer, - udp_pkt.dest_port.val(), - udp_pkt.src_port.val(), - data_cloned_len as u16, - ); - udp_layer.data = data_cloned; - let data_2_send = udp_layer.serialize(); - let start_udp = - data_2_send.len() - (UDPPacket::packet_size() as usize + data_cloned_len); - let start_ip = start_udp - (IPPacket::packet_size() as usize); - udp_layer - .ip_packet - .calculate_checksum(&data_2_send[start_ip..start_udp]); - udp_layer.calculate_checksum(&data_2_send[start_udp..]); - let data_2_send_final = udp_layer.serialize(); - disable_network_interrupts(); - NET_INFO - .lock() - .get_ref() - .unwrap() - .send_packet(&data_2_send_final); - enable_network_interrupts(); - } - println!("[INFO] Socket is closing"); - socket.close(); - } - Err(err) => println!("{:?}", err), - } + hlt_loop(); } #[cfg(test)] test_main(); let mut executor = Executor::new(); - executor.spawn(Task::new(example_task())); executor.spawn(Task::new(keyboard::print_keypresses())); + executor.spawn(Task::new(udp_echo::udp_echo_server())); + executor.spawn(Task::new(example_task())); executor.run(); } diff --git a/kernel/src/network/init.rs b/kernel/src/network/init.rs index 0f74bf6..078c8cf 100644 --- a/kernel/src/network/init.rs +++ b/kernel/src/network/init.rs @@ -40,14 +40,12 @@ pub fn init_dhcp(wait_timeout: u8) -> bool { rtl_dev_info.send_packet(&packet_data); // send first packet drop(rtl_dev_guard); enable_network_interrupts(); - println!("Sent packet"); // get response let mut timeout = 0; let pkt_data; loop { if let Some(dhcp_res) = socket.get_packet_with_timeout(1) { - println!("Got good packet"); pkt_data = dhcp_res; break; } else { @@ -71,8 +69,16 @@ pub fn init_dhcp(wait_timeout: u8) -> bool { if pkt_data.get_type() == LayerType::DHCP { let dhcp_res = pkt_data.unwrap_dhcp(); rtl_dev_info.my_ip_address = Some(dhcp_res.my_ip.val()); + let ip = dhcp_res.my_ip.swapped_endianness(); + println!("[INFO] IP-Address Assigned As {}.{}.{}.{}", ip[0], ip[1], ip[2], ip[3]); rtl_dev_info.dhcp_server_ip = Some(dhcp_res.server_ip.val()); } enable_network_interrupts(); true } + +pub async fn process_packet_data() -> ! { + loop { + + } +} \ No newline at end of file diff --git a/kernel/src/network/rtl8139.rs b/kernel/src/network/rtl8139.rs index 93d6d77..879f355 100644 --- a/kernel/src/network/rtl8139.rs +++ b/kernel/src/network/rtl8139.rs @@ -22,6 +22,7 @@ use crate::network::raw_array::WrappingRawArray; use super::constants::{ BROADCAST_ADDR, INTERRUPT_MASK, ROK, RTL_DEV, RTL_VEND, TOK, TRANSMIT_CMD, TRANSMIT_REG, }; +use super::socket::wake_sockets; use super::{ arp_table::ArpEntry, devices::{Device, PCIClassCodes}, @@ -161,6 +162,7 @@ pub extern "x86-interrupt" fn network_handle(_stack_frame: InterruptStackFrame) .get_mut(&dst_port) .unwrap() .push_back(PacketData::DHCP(dhcp)); + wake_sockets(); } } PacketData::UDP(udp) => { @@ -175,6 +177,7 @@ pub extern "x86-interrupt" fn network_handle(_stack_frame: InterruptStackFrame) .get_mut(&dst_port) .unwrap() .push_back(PacketData::UDP(udp)); + wake_sockets(); } } _ => {} // ignore others diff --git a/kernel/src/network/socket.rs b/kernel/src/network/socket.rs index 7ad42c4..67da8f6 100644 --- a/kernel/src/network/socket.rs +++ b/kernel/src/network/socket.rs @@ -1,7 +1,12 @@ +use futures_util::{Stream, task::AtomicWaker}; + use super::{ layer::PacketData, rtl8139::{disable_network_interrupts, enable_network_interrupts, NET_INFO}, }; +use core::task::{Context, Poll}; + +static WAKER: AtomicWaker = AtomicWaker::new(); #[derive(Debug)] pub enum NetworkErrors { @@ -87,3 +92,23 @@ impl RawSocket { enable_network_interrupts(); } } + +impl Stream for RawSocket { + type Item = PacketData; + + fn poll_next(self: core::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + disable_network_interrupts(); + let pkt = self.try_get_packet_inner(); + enable_network_interrupts(); + WAKER.register(cx.waker()); + if pkt.is_some() { + Poll::Ready(pkt) + } else { + Poll::Pending + } + } +} + +pub(crate) fn wake_sockets() { + WAKER.wake(); +} \ No newline at end of file diff --git a/kernel/src/task/mod.rs b/kernel/src/task/mod.rs index 255356f..533dbc6 100644 --- a/kernel/src/task/mod.rs +++ b/kernel/src/task/mod.rs @@ -6,6 +6,7 @@ use core::{future::Future, pin::Pin}; pub mod executor; pub mod keyboard; pub mod simple_executor; +pub mod udp_echo; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] struct TaskId(u64); diff --git a/kernel/src/task/udp_echo.rs b/kernel/src/task/udp_echo.rs new file mode 100644 index 0000000..3af7c83 --- /dev/null +++ b/kernel/src/task/udp_echo.rs @@ -0,0 +1,82 @@ +use crate::{ + network::{ + ethernet::{self, EthernetPacket}, + ip::{IPPacket, Protocol}, + layer::{HasChecksum, Layer, LayerType}, + rtl8139::{disable_network_interrupts, enable_network_interrupts, NET_INFO}, + socket::RawSocket, + udp::UDPPacket, + }, + print, println, +}; +use alloc::string::String; +use futures_util::StreamExt; + +pub async fn udp_echo_server() { + let raw_socket = RawSocket::new(5554); + match raw_socket { + Ok(mut socket) => { + while let Some(pkt) = socket.next().await { + if pkt.get_type() != LayerType::UDP { + break; + } + let udp_pkt = pkt.unwrap_udp(); + let data_cloned = udp_pkt.data.clone(); + let data_cloned_len = data_cloned.len(); + let user_message = String::from_utf8(udp_pkt.data); + match user_message { + Ok(message) => { + print!("[USER] {}", message); + if message == "BYE" || message == "BYE\n" { + break; + } + } + Err(err) => println!("[USER-ERR] {:?}", err), + } + + // send back a copy of the packet ("echo") + let ip_layer_res = udp_pkt.ip_packet; + let eth_layer_res = ip_layer_res.ethernet_packet; + let eth_layer = EthernetPacket::gen( + eth_layer_res.src_mac.val(), + eth_layer_res.dest_mac.val(), + ethernet::EthType::IPv4, + ); + let udp_size = UDPPacket::packet_size() + data_cloned_len as u16; + let ip_layer = IPPacket::gen( + eth_layer, + udp_size, + Protocol::UDP, + ip_layer_res.destination_ip.val(), + ip_layer_res.source_ip.val(), + ); + let mut udp_layer = UDPPacket::gen( + ip_layer, + udp_pkt.dest_port.val(), + udp_pkt.src_port.val(), + data_cloned_len as u16, + ); + udp_layer.data = data_cloned; + let data_2_send = udp_layer.serialize(); + let start_udp = + data_2_send.len() - (UDPPacket::packet_size() as usize + data_cloned_len); + let start_ip = start_udp - (IPPacket::packet_size() as usize); + udp_layer + .ip_packet + .calculate_checksum(&data_2_send[start_ip..start_udp]); + udp_layer.calculate_checksum(&data_2_send[start_udp..]); + let data_2_send_final = udp_layer.serialize(); + disable_network_interrupts(); + NET_INFO + .lock() + .get_ref() + .unwrap() + .send_packet(&data_2_send_final); + enable_network_interrupts(); + } + println!("[INFO] Socket is closing"); + socket.close(); + } + Err(err) => println!("{:?}", err), + } +} From ffc4667543e20ae68b1e4990e7c96241b4501dc7 Mon Sep 17 00:00:00 2001 From: EthanLavi <89608076+EthanLavi@users.noreply.github.com> Date: Tue, 7 Nov 2023 18:33:12 -0500 Subject: [PATCH 08/36] Interrupt handler shorter with async processing --- kernel/src/main.rs | 20 +++-- kernel/src/network/init.rs | 27 ++++--- kernel/src/network/mod.rs | 4 +- kernel/src/network/processing.rs | 133 +++++++++++++++++++++++++++++++ kernel/src/network/raw_socket.rs | 91 +++++++++++++++++++++ kernel/src/network/rtl8139.rs | 83 ++----------------- kernel/src/network/socket.rs | 114 -------------------------- kernel/src/task/udp_echo.rs | 2 +- 8 files changed, 267 insertions(+), 207 deletions(-) create mode 100644 kernel/src/network/processing.rs create mode 100644 kernel/src/network/raw_socket.rs diff --git a/kernel/src/main.rs b/kernel/src/main.rs index d4042e6..acb1f9f 100644 --- a/kernel/src/main.rs +++ b/kernel/src/main.rs @@ -12,7 +12,7 @@ use core::panic::PanicInfo; use kernel::{ framebuffer, hlt_loop, network::{ - init::init_dhcp, + init::{init_dhcp, init_process_packet_data}, rtl8139::NET_INFO, }, println, @@ -54,6 +54,14 @@ async fn example_task() { println!("async number: {}", number); } +async fn do_init_dhcp() { + let status_init_dhcp = init_dhcp(4).await; + if !status_init_dhcp { + println!("[ERR] DHCP error -- whats my ip?"); + hlt_loop(); + } +} + fn kernel_main(boot_info: &'static mut BootInfo) -> ! { use kernel::allocator; use kernel::memory; @@ -80,22 +88,22 @@ fn kernel_main(boot_info: &'static mut BootInfo) -> ! { .unwrap() .init(&mut frame_allocator, phys_mem_offset) }; // so that the NET INFO gets released - let status_init_dhcp = init_dhcp(10); if !status_init { println!("[ERR] Cannot init RTL8139"); hlt_loop(); - } else if !status_init_dhcp { - println!("[ERR] DHCP error -- whats my ip?"); - hlt_loop(); } #[cfg(test)] test_main(); let mut executor = Executor::new(); + // Start the processing of pending packets + init_process_packet_data(&mut executor); + executor.spawn(Task::new(do_init_dhcp())); // not entirely async, will finish + executor.spawn(Task::new(keyboard::print_keypresses())); - executor.spawn(Task::new(udp_echo::udp_echo_server())); executor.spawn(Task::new(example_task())); + executor.spawn(Task::new(udp_echo::udp_echo_server())); executor.run(); } diff --git a/kernel/src/network/init.rs b/kernel/src/network/init.rs index 078c8cf..4f28a6a 100644 --- a/kernel/src/network/init.rs +++ b/kernel/src/network/init.rs @@ -1,18 +1,24 @@ +use futures_util::StreamExt; +use x86_64::instructions::hlt; + use super::constants::{BROADCAST_ADDR, BROADCAST_MAC, DHCP_CLIENT_PORT}; +use super::processing; use crate::network::dhcp::DHCPPacket; use crate::network::ethernet::{EthType, EthernetPacket}; use crate::network::ip::{IPPacket, Protocol}; use crate::network::layer::{Layer, LayerType}; use crate::network::rtl8139::{disable_network_interrupts, enable_network_interrupts, NET_INFO}; -use crate::network::socket::RawSocket; +use crate::network::raw_socket::RawSocket; use crate::network::udp::UDPPacket; +use crate::task::Task; +use crate::task::executor::Executor; use crate::{network::constants::DHCP_SERVER_PORT, println}; pub fn init() { // todo bundle the init phases } -pub fn init_dhcp(wait_timeout: u8) -> bool { +pub async fn init_dhcp(wait_timeout: u8) -> bool { // open a socket (shouldn't fail since we are in the init process) let mut socket = RawSocket::new(DHCP_CLIENT_PORT).unwrap(); @@ -45,27 +51,32 @@ pub fn init_dhcp(wait_timeout: u8) -> bool { let mut timeout = 0; let pkt_data; loop { - if let Some(dhcp_res) = socket.get_packet_with_timeout(1) { + if let Some(dhcp_res) = socket.next().await { + println!("Found dhcp data"); pkt_data = dhcp_res; break; } else { + hlt(); disable_network_interrupts(); + { let rtl_dev_guard = NET_INFO.lock(); let rtl_dev_info = rtl_dev_guard.get_ref().unwrap(); rtl_dev_info.send_packet(&packet_data); // send another packet + } enable_network_interrupts(); } timeout += 1; - if timeout == wait_timeout { + if timeout == wait_timeout * 18 { socket.close(); return false; } } - socket.close(); + disable_network_interrupts(); let mut rtl_dev_guard = NET_INFO.lock(); let rtl_dev_info = rtl_dev_guard.get_mut().unwrap(); + println!("Printing ip"); if pkt_data.get_type() == LayerType::DHCP { let dhcp_res = pkt_data.unwrap_dhcp(); rtl_dev_info.my_ip_address = Some(dhcp_res.my_ip.val()); @@ -77,8 +88,6 @@ pub fn init_dhcp(wait_timeout: u8) -> bool { true } -pub async fn process_packet_data() -> ! { - loop { - - } +pub fn init_process_packet_data(exec: &mut Executor) { + exec.spawn(Task::new(processing::init_packet_processing())); } \ No newline at end of file diff --git a/kernel/src/network/mod.rs b/kernel/src/network/mod.rs index 0c3ae54..2ad5eae 100644 --- a/kernel/src/network/mod.rs +++ b/kernel/src/network/mod.rs @@ -8,13 +8,15 @@ pub mod init; pub mod ip; pub mod layer; pub mod rtl8139; -pub mod socket; +pub mod raw_socket; pub mod udp; +pub mod socket; // todo: remove pub until things break... mod arp_table; pub mod constants; mod netsync; mod raw_array; +mod processing; // pub mod e1000; /* diff --git a/kernel/src/network/processing.rs b/kernel/src/network/processing.rs new file mode 100644 index 0000000..6fc0367 --- /dev/null +++ b/kernel/src/network/processing.rs @@ -0,0 +1,133 @@ +use alloc::{vec::Vec, collections::VecDeque}; +use conquer_once::spin::OnceCell; +use crossbeam_queue::ArrayQueue; +use futures_util::{task::AtomicWaker, Stream, StreamExt}; + +use crate::{println, network::{raw_socket::wake_sockets, layer::{PacketData, Layer}, arp_table::ArpEntry, constants::BROADCAST_ADDR, arp::ArpPacket, ethernet::{EthernetPacket, EthType}, rtl8139::{NET_INFO, disable_network_interrupts, enable_network_interrupts}}}; + +use core::{ + pin::Pin, + task::{Context, Poll}, +}; + +use super::layer::full_parse; + +static PROCESS_VEC_WAKER: AtomicWaker = AtomicWaker::new(); +static PENDING_DATA: OnceCell>> = OnceCell::uninit(); + +pub struct PendingProcessingStream { + _private: (), +} + +impl PendingProcessingStream { + pub fn new() -> Self { + PENDING_DATA + .try_init_once(|| ArrayQueue::new(100)) + .expect("PendingProcessingStream::new should only be called once"); + PendingProcessingStream { _private: () } + } +} + +impl Stream for PendingProcessingStream { + type Item = Vec; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll>> { + let queue = PENDING_DATA.try_get().expect("not initialized"); + let data = queue.pop(); + PROCESS_VEC_WAKER.register(cx.waker()); + match data { + Some(pkt_data) => { + Poll::Ready(Some(pkt_data)) + } + None => Poll::Pending, + } + } +} + +pub(crate) fn add_pkt_data(data: Vec) { + if let Ok(queue) = PENDING_DATA.try_get() { + if queue.push(data).is_err() { + println!("[WARN] packet queue full; dropping packet"); + } else { + PROCESS_VEC_WAKER.wake(); + } + } else { + println!("[WARN] packet queue uninitialized"); + } +} + +pub async fn init_packet_processing() { + let mut raw_packets = PendingProcessingStream::new(); + while let Some(pkt_data) = raw_packets.next().await { + let amount_parsed_and_pkt = full_parse(pkt_data.as_slice()); + assert!(amount_parsed_and_pkt.0 == pkt_data.len() || pkt_data.len() < 64); + // Try to get the device info + disable_network_interrupts(); + let mut net_dev = NET_INFO.lock(); + // Get the device fields + let rtl_dev_info = net_dev.get_mut().unwrap(); + match amount_parsed_and_pkt.1 { + PacketData::ARP(arp) => { + // todo: also check for broadcast + if arp.recp_ip.val() == rtl_dev_info.my_ip_address.unwrap_or(0) { + // println!("[INT-HANDLER] Send a response back"); + let eth_layer = EthernetPacket::gen( + arp.sender_mac.val(), + rtl_dev_info.mac_address.unwrap(), + EthType::Arp, + ); + let arp_layer = ArpPacket::gen( + eth_layer, + rtl_dev_info.my_ip_address.unwrap_or(BROADCAST_ADDR), + arp.sender_ip.val(), + false, + ); + let arp_pkt = arp_layer.serialize(); + rtl_dev_info.send_packet(&arp_pkt); + } else { + // println!("[INT-HANDLER] Receiving arp reply"); + // todo: expire from arp table? + rtl_dev_info.arp_table.push(ArpEntry { + mac: arp.sender_mac.val(), + ip: arp.sender_ip.val(), + expires: 0, + }); + } + } + PacketData::DHCP(dhcp) => { + let dst_port = dhcp.udp_packet.dest_port.val(); + println!("[HANDLER] Found DHCP packet"); + if rtl_dev_info.open_ports.contains(&dst_port) { + println!("[HANDLER] Port {} is open", dst_port); + // if we are listening on the port, try to insert it into the map + if !rtl_dev_info.ports.contains_key(&dst_port) { + rtl_dev_info.ports.insert(dst_port, VecDeque::new()); + } + rtl_dev_info + .ports + .get_mut(&dst_port) + .unwrap() + .push_back(PacketData::DHCP(dhcp)); + wake_sockets(dst_port); + } + } + PacketData::UDP(udp) => { + let dst_port = udp.dest_port.val(); + if rtl_dev_info.open_ports.contains(&dst_port) { + // if we are listening on the port, try to insert it into the map + if !rtl_dev_info.ports.contains_key(&dst_port) { + rtl_dev_info.ports.insert(dst_port, VecDeque::new()); + } + rtl_dev_info + .ports + .get_mut(&dst_port) + .unwrap() + .push_back(PacketData::UDP(udp)); + wake_sockets(dst_port); + } + } + _ => {} // ignore others + } + enable_network_interrupts(); + } +} diff --git a/kernel/src/network/raw_socket.rs b/kernel/src/network/raw_socket.rs new file mode 100644 index 0000000..56b0e2a --- /dev/null +++ b/kernel/src/network/raw_socket.rs @@ -0,0 +1,91 @@ +use futures_util::{Stream, task::AtomicWaker}; +use hashbrown::HashMap; + +use super::{ + layer::PacketData, + rtl8139::{disable_network_interrupts, enable_network_interrupts, NET_INFO}, +}; +use core::task::{Context, Poll}; +use lazy_static::lazy_static; + +lazy_static! { + pub static ref NEW_PACKET_WAKER: spin::Mutex> = spin::Mutex::new(HashMap::new()); +} + +#[derive(Debug)] +pub enum NetworkErrors { + PortInUse, +} + +// todo: Implement a socket for user-space... + +pub struct RawSocket { + port: u16, +} + +impl RawSocket { + pub fn new(port: u16) -> Result { + disable_network_interrupts(); + let mut rtl_dev_info_guard = NET_INFO.lock(); + let rtl_dev_info = rtl_dev_info_guard.get_mut().unwrap(); + // Check if the port is in use + if rtl_dev_info.open_ports.contains(&port) { + enable_network_interrupts(); + return Err(NetworkErrors::PortInUse); + } + // If not then bind to it + rtl_dev_info.open_ports.insert(port); + // and allocate a waker + NEW_PACKET_WAKER.lock().insert(port, AtomicWaker::new()); + enable_network_interrupts(); + Ok(RawSocket { port }) + } + + fn try_get_packet_inner(&self) -> Option { + let mut rtl_dev_info_guard = NET_INFO.lock(); + let rtl_dev_info = rtl_dev_info_guard.get_mut().unwrap(); + match rtl_dev_info.ports.get_mut(&self.port) { + Some(vec) => vec.pop_front(), + None => None, + } + } + + pub fn close(&mut self) { + disable_network_interrupts(); + let mut rtl_dev_info_guard = NET_INFO.lock(); + let rtl_dev_info = rtl_dev_info_guard.get_mut().unwrap(); + // close the port so that we don't receive anymore packets + rtl_dev_info.open_ports.remove(&self.port); + // Remove the waker, since the port has no listeners + NEW_PACKET_WAKER.lock().remove(&self.port); + // Try to clear all the pending packets from the port + if rtl_dev_info.ports.contains_key(&self.port) { + let vec = rtl_dev_info.ports.get_mut(&self.port); + vec.unwrap().clear(); + } + enable_network_interrupts(); + } +} + +impl Stream for RawSocket { + type Item = PacketData; + + fn poll_next(self: core::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + disable_network_interrupts(); + let pkt = self.try_get_packet_inner(); + enable_network_interrupts(); + // listen on the port with the waker + NEW_PACKET_WAKER.lock()[&self.port].register(cx.waker()); + if pkt.is_some() { + Poll::Ready(pkt) + } else { + Poll::Pending + } + } +} + +/// Wake sockets by port +pub(crate) fn wake_sockets(port: u16) { + // wake the port up + NEW_PACKET_WAKER.lock()[&port].wake(); +} \ No newline at end of file diff --git a/kernel/src/network/rtl8139.rs b/kernel/src/network/rtl8139.rs index 879f355..f98aa6c 100644 --- a/kernel/src/network/rtl8139.rs +++ b/kernel/src/network/rtl8139.rs @@ -22,7 +22,8 @@ use crate::network::raw_array::WrappingRawArray; use super::constants::{ BROADCAST_ADDR, INTERRUPT_MASK, ROK, RTL_DEV, RTL_VEND, TOK, TRANSMIT_CMD, TRANSMIT_REG, }; -use super::socket::wake_sockets; +use super::processing::add_pkt_data; +use super::raw_socket::wake_sockets; use super::{ arp_table::ArpEntry, devices::{Device, PCIClassCodes}, @@ -111,77 +112,12 @@ pub extern "x86-interrupt" fn network_handle(_stack_frame: InterruptStackFrame) let status = unsafe { port_isr.read() }; // Reset the ISR register unsafe { port_isr.write(0x05) }; - // println!("!! {} !!", status); if status & TOK != 0x0 { // Sent - // println!("Sending packet"); } if status & ROK != 0x0 { - // println!("Receiving packet"); // Received packet - let pkt = recv_packet(rtl_dev_info); - match pkt { - PacketData::ARP(arp) => { - // todo: also check for broadcast - if arp.recp_ip.val() == rtl_dev_info.my_ip_address.unwrap_or(0) { - // println!("[INT-HANDLER] Send a response back"); - let eth_layer = EthernetPacket::gen( - arp.sender_mac.val(), - rtl_dev_info.mac_address.unwrap(), - EthType::Arp, - ); - let arp_layer = ArpPacket::gen( - eth_layer, - rtl_dev_info.my_ip_address.unwrap_or(BROADCAST_ADDR), - arp.sender_ip.val(), - false, - ); - let arp_pkt = arp_layer.serialize(); - rtl_dev_info.send_packet(&arp_pkt); - } else { - // println!("[INT-HANDLER] Receiving arp reply"); - // todo: expire from arp table? - rtl_dev_info.arp_table.push(ArpEntry { - mac: arp.sender_mac.val(), - ip: arp.sender_ip.val(), - expires: 0, - }); - } - } - PacketData::DHCP(dhcp) => { - let dst_port = dhcp.udp_packet.dest_port.val(); - // println!("[INT-HANDLER] Found DHCP packet"); - if rtl_dev_info.open_ports.contains(&dst_port) { - // println!("[INT-HANDLER] Port {} is open", dst_port); - // if we are listening on the port, try to insert it into the map - if !rtl_dev_info.ports.contains_key(&dst_port) { - rtl_dev_info.ports.insert(dst_port, VecDeque::new()); - } - rtl_dev_info - .ports - .get_mut(&dst_port) - .unwrap() - .push_back(PacketData::DHCP(dhcp)); - wake_sockets(); - } - } - PacketData::UDP(udp) => { - let dst_port = udp.dest_port.val(); - if rtl_dev_info.open_ports.contains(&dst_port) { - // if we are listening on the port, try to insert it into the map - if !rtl_dev_info.ports.contains_key(&dst_port) { - rtl_dev_info.ports.insert(dst_port, VecDeque::new()); - } - rtl_dev_info - .ports - .get_mut(&dst_port) - .unwrap() - .push_back(PacketData::UDP(udp)); - wake_sockets(); - } - } - _ => {} // ignore others - } + recv_packet(rtl_dev_info); } // Allow interrupts to the device @@ -193,8 +129,8 @@ pub extern "x86-interrupt" fn network_handle(_stack_frame: InterruptStackFrame) } } -// todo: refactor to be a loop, this function needs to return a list -fn recv_packet(rtl_dev_info: &RTL8139) -> PacketData { +// todo: refactor to be a loop, this function needs to process >1 packet +fn recv_packet(rtl_dev_info: &RTL8139) { if rtl_dev_info.recv_buffer.is_none() || rtl_dev_info.physical_mem_offset.is_none() { panic!("RTL8139 is not initialized properly"); } @@ -223,10 +159,7 @@ fn recv_packet(rtl_dev_info: &RTL8139) -> PacketData { let packet = rx_buffer.trim((length - 4) as usize); // ? throw out the crc... we don't need to check it... rx_buffer.shift_amount(4); - let amount_parsed_and_pkt = full_parse(packet.as_slice()); - - // the amount we parse will be equal to length unless we are under the minimum - assert!(amount_parsed_and_pkt.0 == (length - 4) as usize || length >= 64); + add_pkt_data(packet); // after receiving the packet, update CAPR and RECV_POS // increment recv_pos unsafe { @@ -243,7 +176,6 @@ fn recv_packet(rtl_dev_info: &RTL8139) -> PacketData { let mut capr = Port::::new((rtl_dev_info.config.io_base.unwrap() + CAPR) as u16); // println!("[RECV_POS] {}", unsafe { RECV_POS }); unsafe { capr.write(RECV_POS - 0x10) }; - return amount_parsed_and_pkt.1; } else { unsafe { RECV_POS = (RECV_POS + 2) % RX_BUFFER_SIZE; @@ -251,7 +183,6 @@ fn recv_packet(rtl_dev_info: &RTL8139) -> PacketData { // break; } } - PacketData::UNDEF(EmptyLayer::new()) } // TODO: Split the driver into separate bits so we can lock individual resources? @@ -266,7 +197,7 @@ pub struct RTL8139 { pub mac_address: Option, pub open_ports: HashSet, pub ports: HashMap>, - arp_table: Vec, + pub arp_table: Vec, } impl RTL8139 { diff --git a/kernel/src/network/socket.rs b/kernel/src/network/socket.rs index 67da8f6..e69de29 100644 --- a/kernel/src/network/socket.rs +++ b/kernel/src/network/socket.rs @@ -1,114 +0,0 @@ -use futures_util::{Stream, task::AtomicWaker}; - -use super::{ - layer::PacketData, - rtl8139::{disable_network_interrupts, enable_network_interrupts, NET_INFO}, -}; -use core::task::{Context, Poll}; - -static WAKER: AtomicWaker = AtomicWaker::new(); - -#[derive(Debug)] -pub enum NetworkErrors { - PortInUse, -} - -// todo: Implement a socket for user-space... - -pub struct RawSocket { - port: u16, -} - -impl RawSocket { - pub fn new(port: u16) -> Result { - disable_network_interrupts(); - let mut rtl_dev_info_guard = NET_INFO.lock(); - let rtl_dev_info = rtl_dev_info_guard.get_mut().unwrap(); - // Check if the port is in use - if rtl_dev_info.open_ports.contains(&port) { - enable_network_interrupts(); - return Err(NetworkErrors::PortInUse); - } - // If not then bind to it - rtl_dev_info.open_ports.insert(port); - enable_network_interrupts(); - Ok(RawSocket { port }) - } - - fn try_get_packet_inner(&self) -> Option { - let mut rtl_dev_info_guard = NET_INFO.lock(); - let rtl_dev_info = rtl_dev_info_guard.get_mut().unwrap(); - match rtl_dev_info.ports.get_mut(&self.port) { - Some(vec) => vec.pop_front(), - None => None, - } - } - - // try to get a packet. Won't block, but will return Option - pub fn try_get_packet(&self) -> Option { - disable_network_interrupts(); - let pkt = self.try_get_packet_inner(); - enable_network_interrupts(); - pkt - } - - // Query a port for a packet. Will block until a packet arrives - pub fn get_packet(&self) -> PacketData { - let pkt; - loop { - x86_64::instructions::hlt(); - if let Some(next_pkt) = self.try_get_packet() { - pkt = next_pkt; - break; - } - } - pkt - } - - // Query a port for a packet. Will block until a packet arrives - pub fn get_packet_with_timeout(&self, timeout_s: u32) -> Option { - let mut pkt = None; - for _ in 0..(18 * timeout_s) { - pkt = self.try_get_packet(); - if pkt.is_some() { - break; - } - x86_64::instructions::hlt(); - } - pkt - } - - pub fn close(&mut self) { - disable_network_interrupts(); - let mut rtl_dev_info_guard = NET_INFO.lock(); - let rtl_dev_info = rtl_dev_info_guard.get_mut().unwrap(); - // close the port so that we don't receive anymore packets - rtl_dev_info.open_ports.remove(&self.port); - // Try to clear all the pending packets from the port - if rtl_dev_info.ports.contains_key(&self.port) { - let vec = rtl_dev_info.ports.get_mut(&self.port); - vec.unwrap().clear(); - } - enable_network_interrupts(); - } -} - -impl Stream for RawSocket { - type Item = PacketData; - - fn poll_next(self: core::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - disable_network_interrupts(); - let pkt = self.try_get_packet_inner(); - enable_network_interrupts(); - WAKER.register(cx.waker()); - if pkt.is_some() { - Poll::Ready(pkt) - } else { - Poll::Pending - } - } -} - -pub(crate) fn wake_sockets() { - WAKER.wake(); -} \ No newline at end of file diff --git a/kernel/src/task/udp_echo.rs b/kernel/src/task/udp_echo.rs index 3af7c83..b756610 100644 --- a/kernel/src/task/udp_echo.rs +++ b/kernel/src/task/udp_echo.rs @@ -4,7 +4,7 @@ use crate::{ ip::{IPPacket, Protocol}, layer::{HasChecksum, Layer, LayerType}, rtl8139::{disable_network_interrupts, enable_network_interrupts, NET_INFO}, - socket::RawSocket, + raw_socket::RawSocket, udp::UDPPacket, }, print, println, From b502a3bf3234f948589d2bbf8d5b1bb25c287189 Mon Sep 17 00:00:00 2001 From: EthanLavi <89608076+EthanLavi@users.noreply.github.com> Date: Tue, 7 Nov 2023 20:45:57 -0500 Subject: [PATCH 09/36] Socket API (missing timeout + TCP) --- kernel/src/main.rs | 6 +- kernel/src/network/constants.rs | 5 +- kernel/src/network/init.rs | 6 +- kernel/src/network/processing.rs | 19 ++- kernel/src/network/raw_socket.rs | 14 +- kernel/src/network/rtl8139.rs | 30 +--- kernel/src/network/socket.rs | 265 +++++++++++++++++++++++++++++++ kernel/src/task/executor.rs | 16 +- kernel/src/task/udp_echo.rs | 95 +++-------- 9 files changed, 338 insertions(+), 118 deletions(-) diff --git a/kernel/src/main.rs b/kernel/src/main.rs index acb1f9f..d8e18d2 100644 --- a/kernel/src/main.rs +++ b/kernel/src/main.rs @@ -99,10 +99,10 @@ fn kernel_main(boot_info: &'static mut BootInfo) -> ! { let mut executor = Executor::new(); // Start the processing of pending packets init_process_packet_data(&mut executor); - executor.spawn(Task::new(do_init_dhcp())); // not entirely async, will finish - - executor.spawn(Task::new(keyboard::print_keypresses())); + executor.spawn(Task::new(do_init_dhcp())); // not entirely async, will finish before others are run + executor.wait(); executor.spawn(Task::new(example_task())); + executor.spawn(Task::new(keyboard::print_keypresses())); executor.spawn(Task::new(udp_echo::udp_echo_server())); executor.run(); } diff --git a/kernel/src/network/constants.rs b/kernel/src/network/constants.rs index 9d8addc..348bcbc 100644 --- a/kernel/src/network/constants.rs +++ b/kernel/src/network/constants.rs @@ -3,8 +3,9 @@ pub const BROADCAST_ADDR: u32 = 0xFFFFFFFF; pub const BROADCAST_MAC: u64 = 0xFFFFFFFFFFFF; // Common port numbers -pub const DHCP_CLIENT_PORT: u16 = 68; -pub const DHCP_SERVER_PORT: u16 = 67; +pub const DHCP_CLIENT_PORT: u32 = 68; +pub const DHCP_SERVER_PORT: u32 = 67; +pub const ARP_PORT: u32 = u16::MAX as u32 + 2; // we are using ports above u16 to allow for extended our open "ports" map // RTL device specific info pub const RTL_VEND: u32 = 0x10EC; // Vendor ID for the rtl8139 diff --git a/kernel/src/network/init.rs b/kernel/src/network/init.rs index 4f28a6a..b6cb7d6 100644 --- a/kernel/src/network/init.rs +++ b/kernel/src/network/init.rs @@ -36,8 +36,8 @@ pub async fn init_dhcp(wait_timeout: u8) -> bool { let ip = IPPacket::gen(eth, ip_size, Protocol::UDP, 0x0, BROADCAST_ADDR); let udp = UDPPacket::gen( ip, - DHCP_CLIENT_PORT, - DHCP_SERVER_PORT, + DHCP_CLIENT_PORT as u16, + DHCP_SERVER_PORT as u16, DHCPPacket::packet_size(), ); let dhcp = DHCPPacket::gen(udp, None, rtl_dev_info.mac_address.unwrap()); @@ -52,7 +52,6 @@ pub async fn init_dhcp(wait_timeout: u8) -> bool { let pkt_data; loop { if let Some(dhcp_res) = socket.next().await { - println!("Found dhcp data"); pkt_data = dhcp_res; break; } else { @@ -76,7 +75,6 @@ pub async fn init_dhcp(wait_timeout: u8) -> bool { disable_network_interrupts(); let mut rtl_dev_guard = NET_INFO.lock(); let rtl_dev_info = rtl_dev_guard.get_mut().unwrap(); - println!("Printing ip"); if pkt_data.get_type() == LayerType::DHCP { let dhcp_res = pkt_data.unwrap_dhcp(); rtl_dev_info.my_ip_address = Some(dhcp_res.my_ip.val()); diff --git a/kernel/src/network/processing.rs b/kernel/src/network/processing.rs index 6fc0367..1b121a2 100644 --- a/kernel/src/network/processing.rs +++ b/kernel/src/network/processing.rs @@ -3,7 +3,7 @@ use conquer_once::spin::OnceCell; use crossbeam_queue::ArrayQueue; use futures_util::{task::AtomicWaker, Stream, StreamExt}; -use crate::{println, network::{raw_socket::wake_sockets, layer::{PacketData, Layer}, arp_table::ArpEntry, constants::BROADCAST_ADDR, arp::ArpPacket, ethernet::{EthernetPacket, EthType}, rtl8139::{NET_INFO, disable_network_interrupts, enable_network_interrupts}}}; +use crate::{println, network::{raw_socket::wake_sockets, layer::{PacketData, Layer}, arp_table::ArpEntry, constants::{BROADCAST_ADDR, ARP_PORT}, arp::ArpPacket, ethernet::{EthernetPacket, EthType}, rtl8139::{NET_INFO, disable_network_interrupts, enable_network_interrupts}}}; use core::{ pin::Pin, @@ -85,17 +85,28 @@ pub async fn init_packet_processing() { let arp_pkt = arp_layer.serialize(); rtl_dev_info.send_packet(&arp_pkt); } else { - // println!("[INT-HANDLER] Receiving arp reply"); // todo: expire from arp table? rtl_dev_info.arp_table.push(ArpEntry { mac: arp.sender_mac.val(), ip: arp.sender_ip.val(), expires: 0, }); + if rtl_dev_info.open_ports.contains(&ARP_PORT) { + // if we are listening on the port, try to insert it into the map + if !rtl_dev_info.ports.contains_key(&ARP_PORT) { + rtl_dev_info.ports.insert(ARP_PORT, VecDeque::new()); + } + rtl_dev_info + .ports + .get_mut(&ARP_PORT) + .unwrap() + .push_back(PacketData::ARP(arp)); + wake_sockets(ARP_PORT); + } } } PacketData::DHCP(dhcp) => { - let dst_port = dhcp.udp_packet.dest_port.val(); + let dst_port = dhcp.udp_packet.dest_port.val() as u32; println!("[HANDLER] Found DHCP packet"); if rtl_dev_info.open_ports.contains(&dst_port) { println!("[HANDLER] Port {} is open", dst_port); @@ -112,7 +123,7 @@ pub async fn init_packet_processing() { } } PacketData::UDP(udp) => { - let dst_port = udp.dest_port.val(); + let dst_port = udp.dest_port.val() as u32; if rtl_dev_info.open_ports.contains(&dst_port) { // if we are listening on the port, try to insert it into the map if !rtl_dev_info.ports.contains_key(&dst_port) { diff --git a/kernel/src/network/raw_socket.rs b/kernel/src/network/raw_socket.rs index 56b0e2a..8a705ce 100644 --- a/kernel/src/network/raw_socket.rs +++ b/kernel/src/network/raw_socket.rs @@ -9,22 +9,25 @@ use core::task::{Context, Poll}; use lazy_static::lazy_static; lazy_static! { - pub static ref NEW_PACKET_WAKER: spin::Mutex> = spin::Mutex::new(HashMap::new()); + pub static ref NEW_PACKET_WAKER: spin::Mutex> = spin::Mutex::new(HashMap::new()); } #[derive(Debug)] pub enum NetworkErrors { PortInUse, + NoAvailablePort, + NonexistentHost, + SocketInServerMode, } // todo: Implement a socket for user-space... pub struct RawSocket { - port: u16, + port: u32, } impl RawSocket { - pub fn new(port: u16) -> Result { + pub fn new(port: u32) -> Result { disable_network_interrupts(); let mut rtl_dev_info_guard = NET_INFO.lock(); let rtl_dev_info = rtl_dev_info_guard.get_mut().unwrap(); @@ -50,7 +53,7 @@ impl RawSocket { } } - pub fn close(&mut self) { + pub fn close(self) { disable_network_interrupts(); let mut rtl_dev_info_guard = NET_INFO.lock(); let rtl_dev_info = rtl_dev_info_guard.get_mut().unwrap(); @@ -70,6 +73,7 @@ impl RawSocket { impl Stream for RawSocket { type Item = PacketData; + // todo: What happens if we have packet loss? Won't we deadlock because we will be pending and never get a notification? fn poll_next(self: core::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { disable_network_interrupts(); let pkt = self.try_get_packet_inner(); @@ -85,7 +89,7 @@ impl Stream for RawSocket { } /// Wake sockets by port -pub(crate) fn wake_sockets(port: u16) { +pub(crate) fn wake_sockets(port: u32) { // wake the port up NEW_PACKET_WAKER.lock()[&port].wake(); } \ No newline at end of file diff --git a/kernel/src/network/rtl8139.rs b/kernel/src/network/rtl8139.rs index f98aa6c..4752d84 100644 --- a/kernel/src/network/rtl8139.rs +++ b/kernel/src/network/rtl8139.rs @@ -20,24 +20,20 @@ use crate::network::constants::{ use crate::network::raw_array::WrappingRawArray; use super::constants::{ - BROADCAST_ADDR, INTERRUPT_MASK, ROK, RTL_DEV, RTL_VEND, TOK, TRANSMIT_CMD, TRANSMIT_REG, + INTERRUPT_MASK, ROK, RTL_DEV, RTL_VEND, TOK, TRANSMIT_CMD, TRANSMIT_REG, }; use super::processing::add_pkt_data; -use super::raw_socket::wake_sockets; use super::{ arp_table::ArpEntry, devices::{Device, PCIClassCodes}, - layer::{full_parse, EmptyLayer, PacketData}, + layer::PacketData, netsync::InterruptCounter, }; use crate::{ interrupts::{InterruptHandler, PICS}, memory::BootInfoFrameAllocator, network::{ - arp::ArpPacket, devices, - ethernet::{EthType, EthernetPacket}, - layer::Layer, netsync::SafeRTL8139, }, println, @@ -195,8 +191,8 @@ pub struct RTL8139 { pub my_ip_address: Option, pub dhcp_server_ip: Option, pub mac_address: Option, - pub open_ports: HashSet, - pub ports: HashMap>, + pub open_ports: HashSet, + pub ports: HashMap>, pub arp_table: Vec, } @@ -375,21 +371,6 @@ impl RTL8139 { true } - // todo: remove this from driver code (add to socket.rs) - pub fn get_mac_from_ip(&self, ip: u32) -> u64 { - for entry in self.arp_table.iter() { - // todo: check for expired arps - if entry.ip == ip { - return entry.mac; - } - } - // send arp packet - - // wait for response - // recursively try again - 0 - } - pub fn send_packet(&self, packet_data: &Vec) { if self.send_buffer.is_none() || self.physical_mem_offset.is_none() { panic!("RTL8139 is not initialized properly"); @@ -425,6 +406,3 @@ impl RTL8139 { }; } } - -// sudo qemu-system-x86_64 -M q35 -serial mon:stdio -nographic -netdev vmnet-bridged,id=net0,ifname=en0 -device rtl8139,netdev=net0,mac=00:11:22:33:44:55 -// diff --git a/kernel/src/network/socket.rs b/kernel/src/network/socket.rs index e69de29..d9d9fbd 100644 --- a/kernel/src/network/socket.rs +++ b/kernel/src/network/socket.rs @@ -0,0 +1,265 @@ +use alloc::vec; +use alloc::vec::Vec; +use futures_util::StreamExt; +use x86_64::instructions::hlt; + +use crate::network::layer::LayerType; + +use super::{ + arp::ArpPacket, + constants::{ARP_PORT, BROADCAST_MAC}, + ethernet::{self, EthType, EthernetPacket}, + ip::{IPPacket, Protocol}, + layer::{HasChecksum, Layer, PacketData}, + raw_socket::{NetworkErrors, RawSocket}, + rtl8139::{disable_network_interrupts, enable_network_interrupts, NET_INFO}, + udp::UDPPacket, +}; + +pub struct NetworkQuery {} + +impl NetworkQuery { + pub async fn get_mac_from_ip(wait_timeout: u32, ip: u32) -> Option { + disable_network_interrupts(); + let mut rtl_dev_info_locked = NET_INFO.lock(); + let rtl_dev_info = rtl_dev_info_locked.get_mut().unwrap(); + for entry in rtl_dev_info.arp_table.iter() { + // todo: check for expired arps + if entry.ip == ip { + return Some(entry.mac); + } + } + // send arp packet + let eth_layer = EthernetPacket::gen( + BROADCAST_MAC, + rtl_dev_info.mac_address.unwrap(), + EthType::Arp, + ); + let arp_layer = ArpPacket::gen(eth_layer, rtl_dev_info.my_ip_address.unwrap(), ip, true); + rtl_dev_info.send_packet(&arp_layer.serialize()); + drop(rtl_dev_info_locked); + enable_network_interrupts(); + + // wait for response + let mut socket = RawSocket::new(ARP_PORT).unwrap(); + let mut timeout = 0; + loop { + if let Some(pkt) = socket.next().await { + if pkt.get_type() != LayerType::ARP { + continue; + } + let arp_pkt = pkt.unwrap_arp(); + return Some(arp_pkt.sender_mac.val()); + } else { + hlt(); + disable_network_interrupts(); + { + let rtl_dev_guard = NET_INFO.lock(); + let rtl_dev_info = rtl_dev_guard.get_ref().unwrap(); + rtl_dev_info.send_packet(&arp_layer.serialize()); + } + enable_network_interrupts(); + } + timeout += 1; + if timeout == wait_timeout * 18 { + socket.close(); + return None; + } + } + } +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum SocketType { + UDP, + TCP, +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +enum SocketState { + Listening, + Ready, +} + +pub struct Socket { + socket_type: SocketType, + socket_state: SocketState, + raw_socket: RawSocket, + dest_port: u16, + dest_address: u32, + dest_mac: u64, + pub src_port: u16, + src_address: u32, + src_mac: u64, +} + +impl Socket { + // Can't send yet + pub async fn open(socket_type: SocketType, src_port: u16) -> Result { + let mut chosen_src_port = src_port; + disable_network_interrupts(); + let rtl_dev_info_locked = NET_INFO.lock(); + let rtl_dev_info = rtl_dev_info_locked.get_ref().unwrap(); + let src_mac = rtl_dev_info.mac_address.unwrap(); + let src_address = rtl_dev_info.my_ip_address.unwrap(); + if src_port == 0 { + let open_ports = &rtl_dev_info.open_ports; + // Linear probe to see if any ports are open + for i in 1000..u16::MAX { + if !open_ports.contains(&(i as u32)) { + chosen_src_port = i; + break; + } + } + } + drop(rtl_dev_info_locked); + enable_network_interrupts(); + if chosen_src_port == 0 { + return Err(NetworkErrors::NoAvailablePort); + } + let raw_socket = RawSocket::new(chosen_src_port as u32); + match raw_socket { + Ok(socket) => Ok(Socket { + socket_type, + raw_socket: socket, + socket_state: SocketState::Listening, + dest_port: 0, + dest_address: 0, + dest_mac: 0, + src_port, + src_address, + src_mac, + }), + Err(err) => Err(err), + } + } + + // Will listen for new connections and create new sessions + // UDP can only listen for one connection (and thus will return none). + pub async fn listen(&mut self) -> Option { + loop { + if let Some(pkt) = self.raw_socket.next().await { + if pkt.get_type() != LayerType::UDP { + continue; + } + let udp_pkt = pkt.unwrap_udp(); + self.dest_port = udp_pkt.src_port.val(); + self.dest_mac = udp_pkt.ip_packet.ethernet_packet.src_mac.val(); + self.dest_address = udp_pkt.ip_packet.source_ip.val(); + self.socket_state = SocketState::Ready; + disable_network_interrupts(); + let mut rtl_dev_info_guard = NET_INFO.lock(); + let rtl_dev_info = rtl_dev_info_guard.get_mut().unwrap(); + // re-enqueue the packet + if let Some(vec) = rtl_dev_info.ports.get_mut(&(self.src_port as u32)) { + vec.push_front(PacketData::UDP(udp_pkt)); + } + enable_network_interrupts(); + return None; + } + } + } + + pub async fn connect( + socket_type: SocketType, + dest_address: u32, + dest_port: u16, + src_port: u16, + ) -> Result { + let mut chosen_src_port = src_port; + disable_network_interrupts(); + let rtl_dev_info_locked = NET_INFO.lock(); + let rtl_dev_info = rtl_dev_info_locked.get_ref().unwrap(); + let src_mac = rtl_dev_info.mac_address.unwrap(); + let src_address = rtl_dev_info.my_ip_address.unwrap(); + if src_port == 0 { + let open_ports = &rtl_dev_info.open_ports; + // Linear probe to see if any ports are open + for i in 1000..u16::MAX { + if !open_ports.contains(&(i as u32)) { + chosen_src_port = i; + break; + } + } + } + let dest_mac = NetworkQuery::get_mac_from_ip(10, dest_address).await; + drop(rtl_dev_info_locked); + enable_network_interrupts(); + if dest_mac.is_none() { + return Err(NetworkErrors::NonexistentHost); + } + if chosen_src_port == 0 { + return Err(NetworkErrors::NoAvailablePort); + } + let raw_socket = RawSocket::new(chosen_src_port as u32); + match raw_socket { + Ok(socket) => Ok(Socket { + socket_type, + socket_state: SocketState::Ready, + raw_socket: socket, + dest_port, + dest_address, + dest_mac: dest_mac.unwrap(), + src_port, + src_address, + src_mac, + }), + Err(err) => Err(err), + } + } + + pub fn close(self) { + self.raw_socket.close(); + } + + pub async fn read(&mut self, size: u32) -> Result, NetworkErrors> { + assert!(self.socket_type == SocketType::UDP); + if self.socket_state == SocketState::Listening { + return Err(NetworkErrors::SocketInServerMode); + } + if let Some(pkt) = self.raw_socket.next().await { + if pkt.get_type() != LayerType::UDP { + return Ok(vec![]); + } + let udp_pkt = pkt.unwrap_udp(); + return Ok(udp_pkt.data); + } + Ok(vec![]) + } + + pub fn write(&self, data: Vec) -> Result<(), NetworkErrors> { + assert!(self.socket_type == SocketType::UDP); + if self.socket_state == SocketState::Listening { + return Err(NetworkErrors::SocketInServerMode); + } + let eth_layer = EthernetPacket::gen(self.src_mac, self.dest_mac, ethernet::EthType::IPv4); + let udp_size = UDPPacket::packet_size() + data.len() as u16; + let ip_layer = IPPacket::gen( + eth_layer, + udp_size, + Protocol::UDP, + self.src_address, + self.dest_address, + ); + let data_len = data.len(); + let mut udp_layer = + UDPPacket::gen(ip_layer, self.src_port, self.dest_port, data_len as u16); + udp_layer.data = data; + let data_2_send = udp_layer.serialize(); + let start_udp = data_2_send.len() - (UDPPacket::packet_size() as usize + data_len); + let start_ip = start_udp - (IPPacket::packet_size() as usize); + udp_layer + .ip_packet + .calculate_checksum(&data_2_send[start_ip..start_udp]); + udp_layer.calculate_checksum(&data_2_send[start_udp..]); + let data_2_send_final = udp_layer.serialize(); + disable_network_interrupts(); + NET_INFO + .lock() + .get_ref() + .unwrap() + .send_packet(&data_2_send_final); + enable_network_interrupts(); + Ok(()) + } +} diff --git a/kernel/src/task/executor.rs b/kernel/src/task/executor.rs index 2b6e1fc..2785ec7 100644 --- a/kernel/src/task/executor.rs +++ b/kernel/src/task/executor.rs @@ -2,6 +2,7 @@ use super::{Task, TaskId}; use alloc::{collections::BTreeMap, sync::Arc, task::Wake}; use core::task::{Context, Poll, Waker}; use crossbeam_queue::ArrayQueue; +use x86_64::instructions::interrupts::{self, enable_and_hlt}; pub struct Executor { tasks: BTreeMap, @@ -33,9 +34,20 @@ impl Executor { } } - fn sleep_if_idle(&self) { - use x86_64::instructions::interrupts::{self, enable_and_hlt}; + // Wait until all tasks finish + pub fn wait(&mut self) { + loop { + self.run_ready_tasks(); + interrupts::disable(); + if self.task_queue.is_empty() { + return; + } else { + interrupts::enable(); + } + } + } + fn sleep_if_idle(&self) { interrupts::disable(); if self.task_queue.is_empty() { enable_and_hlt(); diff --git a/kernel/src/task/udp_echo.rs b/kernel/src/task/udp_echo.rs index b756610..f06a53f 100644 --- a/kernel/src/task/udp_echo.rs +++ b/kernel/src/task/udp_echo.rs @@ -1,82 +1,33 @@ use crate::{ - network::{ - ethernet::{self, EthernetPacket}, - ip::{IPPacket, Protocol}, - layer::{HasChecksum, Layer, LayerType}, - rtl8139::{disable_network_interrupts, enable_network_interrupts, NET_INFO}, - raw_socket::RawSocket, - udp::UDPPacket, - }, - print, println, + print, println, network::socket::{SocketType, Socket}, }; use alloc::string::String; -use futures_util::StreamExt; pub async fn udp_echo_server() { - let raw_socket = RawSocket::new(5554); - match raw_socket { + let socket_or_err = Socket::open(SocketType::UDP, 5554).await; + match socket_or_err { Ok(mut socket) => { - while let Some(pkt) = socket.next().await { - if pkt.get_type() != LayerType::UDP { - break; - } - let udp_pkt = pkt.unwrap_udp(); - let data_cloned = udp_pkt.data.clone(); - let data_cloned_len = data_cloned.len(); - let user_message = String::from_utf8(udp_pkt.data); - match user_message { - Ok(message) => { - print!("[USER] {}", message); - if message == "BYE" || message == "BYE\n" { - break; - } + // Listen for a single connection + socket.listen().await; + loop { + let data_or_err = socket.read(0).await; + if let Ok(data) = data_or_err { + let user_message = String::from_utf8(data.clone()); + match user_message { + Ok(message) => print!("[USER] {}", message), + Err(err) => println!("[USER-ERR] {:?}", err), + } + let res_or_err = socket.write(data); + if let Err(err) = res_or_err { + println!("[ERR] {:?}", err); + break; } - Err(err) => println!("[USER-ERR] {:?}", err), + } else if let Err(err) = data_or_err { + println!("[ERR] {:?}", err); + break; } - - // send back a copy of the packet ("echo") - let ip_layer_res = udp_pkt.ip_packet; - let eth_layer_res = ip_layer_res.ethernet_packet; - let eth_layer = EthernetPacket::gen( - eth_layer_res.src_mac.val(), - eth_layer_res.dest_mac.val(), - ethernet::EthType::IPv4, - ); - let udp_size = UDPPacket::packet_size() + data_cloned_len as u16; - let ip_layer = IPPacket::gen( - eth_layer, - udp_size, - Protocol::UDP, - ip_layer_res.destination_ip.val(), - ip_layer_res.source_ip.val(), - ); - let mut udp_layer = UDPPacket::gen( - ip_layer, - udp_pkt.dest_port.val(), - udp_pkt.src_port.val(), - data_cloned_len as u16, - ); - udp_layer.data = data_cloned; - let data_2_send = udp_layer.serialize(); - let start_udp = - data_2_send.len() - (UDPPacket::packet_size() as usize + data_cloned_len); - let start_ip = start_udp - (IPPacket::packet_size() as usize); - udp_layer - .ip_packet - .calculate_checksum(&data_2_send[start_ip..start_udp]); - udp_layer.calculate_checksum(&data_2_send[start_udp..]); - let data_2_send_final = udp_layer.serialize(); - disable_network_interrupts(); - NET_INFO - .lock() - .get_ref() - .unwrap() - .send_packet(&data_2_send_final); - enable_network_interrupts(); } - println!("[INFO] Socket is closing"); - socket.close(); - } - Err(err) => println!("{:?}", err), + }, + Err(err) => println!("[ERR] {:?}", err), } -} +} \ No newline at end of file From 199d24466ec4ff2c67a7678a2d93d1f4d29bc12c Mon Sep 17 00:00:00 2001 From: EthanLavi <89608076+EthanLavi@users.noreply.github.com> Date: Sat, 11 Nov 2023 19:59:53 -0500 Subject: [PATCH 10/36] Renaming, clone, and comments --- kernel/src/network/arp.rs | 10 +++++----- kernel/src/network/bytefield.rs | 1 + kernel/src/network/constants.rs | 8 ++++++++ kernel/src/network/dhcp.rs | 20 +++++++++----------- kernel/src/network/ethernet.rs | 2 +- kernel/src/network/ip.rs | 14 +++++++------- kernel/src/network/layer.rs | 30 +++++++++++++++++++++--------- kernel/src/network/mod.rs | 22 ++-------------------- kernel/src/network/netsync.rs | 5 +++++ kernel/src/network/udp.rs | 14 +++++++------- 10 files changed, 66 insertions(+), 60 deletions(-) diff --git a/kernel/src/network/arp.rs b/kernel/src/network/arp.rs index 58f4559..23cb700 100644 --- a/kernel/src/network/arp.rs +++ b/kernel/src/network/arp.rs @@ -7,7 +7,7 @@ use alloc::vec; use alloc::vec::Vec; #[derive(Debug)] pub struct ArpPacket { - pub ethernet_packet: EthernetPacket, + pub eth: EthernetPacket, hardware_type: Bytefield16, protocol_type: Bytefield16, hardware_address_length: u8, @@ -23,7 +23,7 @@ impl ArpPacket { // Create an empty packet with all 0s pub fn new() -> Self { ArpPacket { - ethernet_packet: EthernetPacket::new(), + eth: EthernetPacket::new(), hardware_type: Bytefield16::new(0), protocol_type: Bytefield16::new(0), hardware_address_length: 0, @@ -41,7 +41,7 @@ impl ArpPacket { let sender_mac = eth_layer.src_mac; assert!(eth_layer.packet_type == EthType::Arp); ArpPacket { - ethernet_packet: eth_layer, + eth: eth_layer, hardware_type: Bytefield16::new(0x1), // ethernet protocol_type: Bytefield16::new(0x0800), // ipv4 hardware_address_length: 6, // ethernet is the value 6 @@ -62,7 +62,7 @@ impl Layer for ArpPacket { // Read ethernet packet and 28 bytes let mut i = 0; - packet.ethernet_packet = eth_layer; + packet.eth = eth_layer; packet.hardware_type = Bytefield16::read_inc(&bytevec[i..], &mut i); packet.protocol_type = Bytefield16::read_inc(&bytevec[i..], &mut i); packet.hardware_address_length = Bytefield8::read_inc(&bytevec[i..], &mut i).val(); @@ -79,7 +79,7 @@ impl Layer for ArpPacket { fn serialize(&self) -> Vec { let mut res = vec![]; - res.extend(self.ethernet_packet.serialize()); + res.extend(self.eth.serialize()); res.extend(self.hardware_type.data); res.extend(self.protocol_type.data); res.push(self.hardware_address_length); diff --git a/kernel/src/network/bytefield.rs b/kernel/src/network/bytefield.rs index 9ecb0dd..a1fdb65 100644 --- a/kernel/src/network/bytefield.rs +++ b/kernel/src/network/bytefield.rs @@ -58,6 +58,7 @@ macro_rules! bytefield_int { $t { data } } + // get the data in the natural endianness pub fn val(&self) -> $int { let mut res = 0; for i in 0..$size { diff --git a/kernel/src/network/constants.rs b/kernel/src/network/constants.rs index 348bcbc..584f21f 100644 --- a/kernel/src/network/constants.rs +++ b/kernel/src/network/constants.rs @@ -23,3 +23,11 @@ pub const CR_BUFE: u8 = 0x01; // Rx buffer is empty pub const CR: u32 = 0x37; // command register (1byte) pub const CAPR: u32 = 0x38; // current address of packet read (2byte, C mode, initial value 0xFFF0) pub const RX_READ_PTR_MASK: u16 = !0x3; // + +// TCP Constants +pub const TCP_FIN: u8 = 0x1; +pub const TCP_SYN: u8 = 0x2; +pub const TCP_RST: u8 = 0x4; +pub const TCP_PSH: u8 = 0x8; +pub const TCP_ACK: u8 = 0x10; +pub const TCP_URG: u8 = 0x20; \ No newline at end of file diff --git a/kernel/src/network/dhcp.rs b/kernel/src/network/dhcp.rs index 6db687f..453165a 100644 --- a/kernel/src/network/dhcp.rs +++ b/kernel/src/network/dhcp.rs @@ -23,7 +23,7 @@ impl WrappedU32 { static mut ID_GEN: spin::Mutex = spin::Mutex::new(WrappedU32 { data: 0 }); #[derive(Debug)] pub struct DHCPPacket { - pub udp_packet: UDPPacket, // public for checksumming + pub udp: UDPPacket, // public for checksumming op_code: u8, hardware_type: u8, hardware_address_length: u8, @@ -44,7 +44,7 @@ pub struct DHCPPacket { impl DHCPPacket { pub fn new() -> Self { DHCPPacket { - udp_packet: UDPPacket::new(), + udp: UDPPacket::new(), op_code: 0, // 1 byte hardware_type: 0, // 1 byte hardware_address_length: 0, // 1 byte @@ -77,7 +77,7 @@ impl DHCPPacket { client_hardware_address[i] = mac[i]; } let mut dhcp = DHCPPacket { - udp_packet, + udp: udp_packet, op_code: 1, // 1 for is request hardware_type: 1, // ethernet is 1 hardware_address_length: 6, // corresponds with hardware address length @@ -99,10 +99,8 @@ impl DHCPPacket { let start_udp = data.len() - (DHCPPacket::packet_size() as usize + UDPPacket::packet_size() as usize); let start_ip = start_udp - (IPPacket::packet_size() as usize); - dhcp.udp_packet - .ip_packet - .calculate_checksum(&data[start_ip..start_udp]); - dhcp.udp_packet.calculate_checksum(&data[start_udp..]); + dhcp.udp.ip.calculate_checksum(&data[start_ip..start_udp]); + dhcp.udp.calculate_checksum(&data[start_udp..]); dhcp } } @@ -115,7 +113,7 @@ impl Layer for DHCPPacket { { let mut packet = DHCPPacket::new(); // create an empty packet let mut i = 0; - packet.udp_packet = udp_layer; + packet.udp = udp_layer; packet.op_code = Bytefield8::read_inc(&bytevec[i..], &mut i).val(); packet.hardware_type = Bytefield8::read_inc(&bytevec[i..], &mut i).val(); packet.hardware_address_length = Bytefield8::read_inc(&bytevec[i..], &mut i).val(); @@ -132,7 +130,7 @@ impl Layer for DHCPPacket { i += 128; // file i += 64; // options // Ignoring those parts of the DHCP packet. - let left_to_parse = packet.udp_packet.length.val() - 308; + let left_to_parse = packet.udp.length.val() - 308; i += left_to_parse as usize; assert!(i >= 300); // 300 bytes (packet, i, LayerType::UNDEF) @@ -140,7 +138,7 @@ impl Layer for DHCPPacket { fn serialize(&self) -> alloc::vec::Vec { let mut res = vec![]; - res.extend(self.udp_packet.serialize()); + res.extend(self.udp.serialize()); res.push(self.op_code); res.push(self.hardware_type); res.push(self.hardware_address_length); @@ -156,7 +154,7 @@ impl Layer for DHCPPacket { res.extend(self.sname); res.extend(self.file); res.extend(self.options); - assert!(res.len() == (300 + self.udp_packet.serialize().len())); + assert!(res.len() == (300 + self.udp.serialize().len())); res } diff --git a/kernel/src/network/ethernet.rs b/kernel/src/network/ethernet.rs index 9e30442..1c80526 100644 --- a/kernel/src/network/ethernet.rs +++ b/kernel/src/network/ethernet.rs @@ -29,7 +29,7 @@ impl EthType { } // Total size is 14 bytes -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct EthernetPacket { pub dest_mac: Bytefield48, // u48 pub src_mac: Bytefield48, // u48, diff --git a/kernel/src/network/ip.rs b/kernel/src/network/ip.rs index ca0ba0d..b1c4629 100644 --- a/kernel/src/network/ip.rs +++ b/kernel/src/network/ip.rs @@ -42,9 +42,9 @@ impl WrappedU16 { } static mut ID_GEN: spin::Mutex = spin::Mutex::new(WrappedU16 { data: 0 }); -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct IPPacket { - pub ethernet_packet: EthernetPacket, + pub eth: EthernetPacket, version_hlen: u8, // 1 byte type_of_service: u8, // 1 byte pub total_length: Bytefield16, // 2 bytes (public for checksumming) @@ -61,7 +61,7 @@ pub struct IPPacket { impl IPPacket { pub fn new() -> Self { IPPacket { - ethernet_packet: EthernetPacket::new(), + eth: EthernetPacket::new(), version_hlen: 0, type_of_service: 0, total_length: Bytefield16::new(0), @@ -89,7 +89,7 @@ impl IPPacket { Bytefield16::new(id_gen.get()) }; IPPacket { - ethernet_packet, + eth: ethernet_packet, version_hlen: 0x45, type_of_service: 0x0, total_length: Bytefield16::new(data_length + 20), // adding data length and size of IP packet @@ -113,7 +113,7 @@ impl Layer for IPPacket { let mut packet = IPPacket::new(); // create an empty packet // Read 20 bytes let mut i = 0; - packet.ethernet_packet = ethernet_layer; + packet.eth = ethernet_layer; packet.version_hlen = Bytefield8::read_inc(&bytevec[i..], &mut i).val(); packet.type_of_service = Bytefield8::read_inc(&bytevec[i..], &mut i).val(); packet.total_length = Bytefield16::read_inc(&bytevec[i..], &mut i); @@ -138,7 +138,7 @@ impl Layer for IPPacket { fn serialize(&self) -> Vec { let mut res = vec![]; - res.extend(self.ethernet_packet.serialize()); + res.extend(self.eth.serialize()); res.push(self.version_hlen); res.push(self.type_of_service); res.extend(self.total_length.data); @@ -149,7 +149,7 @@ impl Layer for IPPacket { res.extend(self.checksum.data); res.extend(self.source_ip.data); res.extend(self.destination_ip.data); - assert!(res.len() == (20 + self.ethernet_packet.serialize().len())); + assert!(res.len() == (20 + self.eth.serialize().len())); res } diff --git a/kernel/src/network/layer.rs b/kernel/src/network/layer.rs index b8828e6..90f4998 100644 --- a/kernel/src/network/layer.rs +++ b/kernel/src/network/layer.rs @@ -5,6 +5,7 @@ use super::arp::ArpPacket; use super::dhcp::DHCPPacket; use super::ethernet::EthernetPacket; use super::ip::IPPacket; +use super::tcp::TCPPacket; use super::udp::UDPPacket; pub trait Layer { @@ -73,7 +74,7 @@ pub enum PacketData { UDP(UDPPacket), ICMP(EmptyLayer), DHCP(DHCPPacket), - TCP(EmptyLayer), + TCP(TCPPacket), UNDEF(EmptyLayer), } @@ -114,6 +115,12 @@ impl PacketData { _ => unreachable!("Mismatched type. Couldn't unwrap"), } } + pub fn unwrap_tcp(self) -> TCPPacket { + match self { + PacketData::TCP(val) => val, + _ => unreachable!("Mismatched type. Couldn't unwrap"), + } + } pub fn get_type(&self) -> LayerType { match self { PacketData::ETH(_) => LayerType::ETH, @@ -141,7 +148,7 @@ pub fn full_parse(packet: &[u8]) -> (usize, PacketData) { last_layer = PacketData::ETH(eth_layer); i += size; next_type = network_layer_type; - } + }, LayerType::IP => { let last_layer_data = last_layer.unwrap_eth(); let (ip_layer, size, transport_layer_type) = @@ -149,7 +156,7 @@ pub fn full_parse(packet: &[u8]) -> (usize, PacketData) { last_layer = PacketData::IP(ip_layer); i += size; next_type = transport_layer_type; - } + }, LayerType::ARP => { let last_layer_data = last_layer.unwrap_eth(); let (arp_layer, size, transport_layer_type) = @@ -157,7 +164,7 @@ pub fn full_parse(packet: &[u8]) -> (usize, PacketData) { last_layer = PacketData::ARP(arp_layer); i += size; next_type = transport_layer_type; - } + }, LayerType::UDP => { let last_layer_data = last_layer.unwrap_ip(); let (udp_layer, size, application_layer_type) = @@ -165,10 +172,10 @@ pub fn full_parse(packet: &[u8]) -> (usize, PacketData) { last_layer = PacketData::UDP(udp_layer); i += size; next_type = application_layer_type; - } + }, LayerType::ICMP => { return (0, PacketData::UNDEF(EmptyLayer::new())); - } + }, LayerType::DHCP => { let last_layer_data = last_layer.unwrap_udp(); let (dhcp_layer, size, empty_type) = @@ -176,10 +183,15 @@ pub fn full_parse(packet: &[u8]) -> (usize, PacketData) { last_layer = PacketData::DHCP(dhcp_layer); i += size; next_type = empty_type; - } + }, LayerType::TCP => { - return (0, PacketData::UNDEF(EmptyLayer::new())); - } + let last_layer_data = last_layer.unwrap_ip(); + let (tcp_layer, size, empty_type) = + TCPPacket::parse(last_layer_data, &packet[i..]); + last_layer = PacketData::TCP(tcp_layer); + i += size; + next_type = empty_type; + }, LayerType::UNDEF => { return (i, last_layer); } diff --git a/kernel/src/network/mod.rs b/kernel/src/network/mod.rs index 2ad5eae..2ab317e 100644 --- a/kernel/src/network/mod.rs +++ b/kernel/src/network/mod.rs @@ -11,29 +11,11 @@ pub mod rtl8139; pub mod raw_socket; pub mod udp; pub mod socket; +pub mod tcp; +mod tcp_session; // todo: remove pub until things break... mod arp_table; pub mod constants; mod netsync; mod raw_array; mod processing; - -// pub mod e1000; -/* -pub struct NetworkIO { - -} - -impl NetworkIO { - // Create a new NetworkIO instance (to use all possible network drivers? IDK) - fn new() -> Self { - return NetworkIO { - - }; - } - - // Read data from the network card - fn read() -> () { - - } -}*/ diff --git a/kernel/src/network/netsync.rs b/kernel/src/network/netsync.rs index 115f803..c8b6c9c 100644 --- a/kernel/src/network/netsync.rs +++ b/kernel/src/network/netsync.rs @@ -2,6 +2,10 @@ use spin::MutexGuard; use super::rtl8139::RTL8139; +struct InterruptGuard { + +} + pub struct NetworkInterruptsGuard<'a> { data: MutexGuard<'a, Option>, } @@ -33,6 +37,7 @@ impl SafeRTL8139 { pub fn new(data: spin::Mutex>) -> Self { Self { data } } + pub fn lock(&self) -> NetworkInterruptsGuard { // disable_network_interrupts(); return NetworkInterruptsGuard { diff --git a/kernel/src/network/udp.rs b/kernel/src/network/udp.rs index 95f4ebb..23c12c0 100644 --- a/kernel/src/network/udp.rs +++ b/kernel/src/network/udp.rs @@ -9,7 +9,7 @@ use alloc::vec::Vec; #[derive(Debug)] pub struct UDPPacket { - pub ip_packet: IPPacket, // public for checksumming + pub ip: IPPacket, // public for checksumming pub src_port: Bytefield16, // 2 bytes pub dest_port: Bytefield16, // 2 bytes pub length: Bytefield16, // 2 bytes @@ -21,7 +21,7 @@ pub struct UDPPacket { impl UDPPacket { pub fn new() -> Self { UDPPacket { - ip_packet: IPPacket::new(), + ip: IPPacket::new(), src_port: Bytefield16::new(0), dest_port: Bytefield16::new(0), length: Bytefield16::new(0), @@ -32,7 +32,7 @@ impl UDPPacket { pub fn gen(ip_packet: IPPacket, src_port: u16, dest_port: u16, length: u16) -> Self { UDPPacket { - ip_packet, + ip: ip_packet, src_port: Bytefield16::new(src_port), dest_port: Bytefield16::new(dest_port), length: Bytefield16::new(length + 8), // size of body + 8 bytes for UDP @@ -51,7 +51,7 @@ impl Layer for UDPPacket { let mut packet = UDPPacket::new(); // create an empty packet // Read 14 bytes let mut i = 0; - packet.ip_packet = ip_layer; + packet.ip = ip_layer; packet.src_port = Bytefield16::read_inc(&bytevec[i..], &mut i); packet.dest_port = Bytefield16::read_inc(&bytevec[i..], &mut i); packet.length = Bytefield16::read_inc(&bytevec[i..], &mut i); @@ -74,13 +74,13 @@ impl Layer for UDPPacket { fn serialize(&self) -> alloc::vec::Vec { let mut res = vec![]; - res.extend(self.ip_packet.serialize()); + res.extend(self.ip.serialize()); res.extend(self.src_port.data); res.extend(self.dest_port.data); res.extend(self.length.data); res.extend(self.checksum.data); res.extend(&self.data); // most of the time should be empty - assert!(res.len() == (8 + self.ip_packet.serialize().len() + self.data.len())); + assert!(res.len() == (8 + self.ip.serialize().len() + self.data.len())); res } @@ -97,7 +97,7 @@ impl HasChecksum for UDPPacket { let mut udp_len = self.length.swapped_endianness().val() as usize; // First we do the IP as a pseduo header - let ip = &self.ip_packet; + let ip = &self.ip; sum += (ip.source_ip.data[0] as u32) | (ip.source_ip.data[1] as u32) << 8; sum += (ip.source_ip.data[2] as u32) | (ip.source_ip.data[3] as u32) << 8; sum += (ip.destination_ip.data[0] as u32) | (ip.destination_ip.data[1] as u32) << 8; From 4895672f76be05e08f34a0432ebf256c0ecdcda0 Mon Sep 17 00:00:00 2001 From: EthanLavi <89608076+EthanLavi@users.noreply.github.com> Date: Sun, 12 Nov 2023 15:21:22 -0500 Subject: [PATCH 11/36] Timeout --- kernel/src/interrupts.rs | 3 +- kernel/src/main.rs | 3 +- kernel/src/network/README.md | 16 +++++- kernel/src/network/TODO.md | 4 +- kernel/src/task/timeout.rs | 95 ++++++++++++++++++++++++++++++++++++ kernel/src/task/udp_echo.rs | 22 +++++++-- 6 files changed, 134 insertions(+), 9 deletions(-) create mode 100644 kernel/src/task/timeout.rs diff --git a/kernel/src/interrupts.rs b/kernel/src/interrupts.rs index d986f3f..f0bdbe9 100644 --- a/kernel/src/interrupts.rs +++ b/kernel/src/interrupts.rs @@ -1,4 +1,4 @@ -use crate::{gdt, hlt_loop}; +use crate::{gdt, hlt_loop, task::timeout::poll_timeouts}; use lazy_static::lazy_static; use x86_64::structures::idt::{InterruptDescriptorTable, InterruptStackFrame, PageFaultErrorCode}; use crate::println; @@ -117,6 +117,7 @@ extern "x86-interrupt" fn double_fault_handler( } extern "x86-interrupt" fn timer_interrupt_handler(_stack_frame: InterruptStackFrame) { + poll_timeouts(); unsafe { PICS.lock() .notify_end_of_interrupt(InterruptIndex::Timer.as_u8()); diff --git a/kernel/src/main.rs b/kernel/src/main.rs index d8e18d2..9542a98 100644 --- a/kernel/src/main.rs +++ b/kernel/src/main.rs @@ -17,7 +17,7 @@ use kernel::{ }, println, task::keyboard, - task::udp_echo, + task::{udp_echo, tcp_echo}, task::{executor::Executor, Task}, }; @@ -104,6 +104,7 @@ fn kernel_main(boot_info: &'static mut BootInfo) -> ! { executor.spawn(Task::new(example_task())); executor.spawn(Task::new(keyboard::print_keypresses())); executor.spawn(Task::new(udp_echo::udp_echo_server())); + executor.spawn(Task::new(tcp_echo::tcp_echo_server())); executor.run(); } diff --git a/kernel/src/network/README.md b/kernel/src/network/README.md index ab4ddc2..ccbb4fb 100644 --- a/kernel/src/network/README.md +++ b/kernel/src/network/README.md @@ -8,7 +8,9 @@ TODO [x] RTL8139 Driver Code [x] Ethernet, IP, UDP, ARP, DHCP [x] RawSocket API -[] Better Socket API +[x] Better Socket API +[x] Async IO +[x] Timeouts [] Refactor so that all of networking is tested [] Refactor to include more documentation on the network module [] Refactor to verify checksums @@ -19,3 +21,15 @@ TODO [] TCP [] Refactor to be all constants [] search for todo and fix thoses + +## Receiving packets + +Firstly, in rtl8139.rs, we receive a device interrupt that a new packet has arrived (length + data). To keep the interrupt handler short, we copy the contents of the data to a buffer and wake the processor. This processor is in processing.rs and async-ly deserializes the buffer. After doing so, we determine which port it belongs to. See the discussion below about the design difference compared to the traditional notion of ports. (For tcp, we do a little more processing before sending it to the port by acknowledging the receipt of data, tcp_session.rs). Furthermore, once the port is deteremined, we can check the open ports to see if something truly needs the data. If so, we insert it into a port-specific buffer and wake the raw_socket (raw_socket.rs). The raw socket is a stream which is queried by a Socket (socket.rs) that owns it. This allows for async read. + +### u64 "ports" + +Traditionally, outward facing ports are from 0-65535, or a u16. But to handle sockets that don't necessarily bind to these ports, the internal ports for receiving data goes up to a u64. As a result, we have a whole space to communicate on. For instance, ARP, while waiting for a response, listens on "port" 65537. This reservation allows the correct direction of packets to the interpretation of them. + +### Timeouts + +Because we are dealing with async in our socket api, when we put the raw_socket task to sleep - it won't wake up until a packet is received so that processing.rs can wake it. As a result, I introduced timeouts. Where the socket will either wake on receipt of a packet or by the handler of the timer interrupt. We register a timeout by saving a waker and how many epochs (iterations of the timer = ~1/18 of a second) into a minheap. Then in the handler, we can draw the minimum value from the heap and check if we've reached that epoch. If so, we wake the task and draw the next minimum. diff --git a/kernel/src/network/TODO.md b/kernel/src/network/TODO.md index 064f2f1..5207120 100644 --- a/kernel/src/network/TODO.md +++ b/kernel/src/network/TODO.md @@ -5,8 +5,10 @@ * Refactor to verify checksums * Verify other parts of the packet * Fix synchronization to be much cleaner +* Fix checksums to be baked-in * Clean up ugly stuff * DHCP parse additional options * TCP * Refactor to be all constants -* search for todo and fix thoses +* Search for todo and fix thoses +* Benchmarking diff --git a/kernel/src/task/timeout.rs b/kernel/src/task/timeout.rs new file mode 100644 index 0000000..fb32813 --- /dev/null +++ b/kernel/src/task/timeout.rs @@ -0,0 +1,95 @@ +use core::{sync::atomic::AtomicU64, cell::RefCell, task::Waker}; +use alloc::collections::BinaryHeap; +use lazy_static::lazy_static; +use x86_64::instructions::interrupts; + +static mut INTERRUPT_COUNTER: u64 = 0; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub struct TimeoutID(u64); + +impl TimeoutID { + pub fn new() -> Self { + static NEXT_ID: AtomicU64 = AtomicU64::new(0); + TimeoutID(NEXT_ID.fetch_add(1, core::sync::atomic::Ordering::Relaxed)) + } +} + +struct TimeoutEntry { + id: TimeoutID, + epochs: u64, + waker: Waker, + cancelled: bool, +} + +impl TimeoutEntry { + pub fn new(id: TimeoutID, epochs: u64, waker: Waker) -> Self { + TimeoutEntry { id, epochs, waker, cancelled: false } + } +} + +impl PartialEq for TimeoutEntry { + fn eq(&self, other: &Self) -> bool { + self.epochs == other.epochs + } +} + +impl PartialOrd for TimeoutEntry { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Eq for TimeoutEntry {} + +impl Ord for TimeoutEntry { + fn cmp(&self, other: &Self) -> core::cmp::Ordering { + self.epochs.cmp(&other.epochs).reverse() + } +} + +lazy_static! { + static ref TIMEOUT_QUEUE: spin::Mutex>> = spin::Mutex::new(BinaryHeap::new()); +} + +pub fn read_interrupt_counter() -> u64 { + unsafe { INTERRUPT_COUNTER } +} + +// Each epoch is ~1/18 of a second, experimentally +pub fn register_timeout(after_epochs: u16, waker: Waker) -> TimeoutID { + let timeout_id = TimeoutID::new(); + interrupts::without_interrupts(|| { + let mut timeout_queue = TIMEOUT_QUEUE.lock(); + timeout_queue.push(RefCell::new(TimeoutEntry::new(timeout_id, unsafe { INTERRUPT_COUNTER } + after_epochs as u64, waker))); + }); + timeout_id +} + +pub fn cancel_timeout(id: TimeoutID) { + interrupts::without_interrupts(|| { + let timeout_queue = TIMEOUT_QUEUE.lock(); + for entry in timeout_queue.iter() { + if entry.borrow().id.0 == id.0 { + entry.borrow_mut().cancelled = true; + break; + } + } + }); +} + +// Only run from the interrupt context +pub fn poll_timeouts() { + let mut timeout_queue = TIMEOUT_QUEUE.lock(); + unsafe { INTERRUPT_COUNTER += 1 }; + while let Some(timeout_entry) = timeout_queue.peek() { + if timeout_entry.borrow().epochs <= unsafe { INTERRUPT_COUNTER } { + if !timeout_entry.borrow().cancelled { + timeout_entry.borrow().waker.wake_by_ref(); + } + timeout_queue.pop(); + } else { + break; + } + } +} \ No newline at end of file diff --git a/kernel/src/task/udp_echo.rs b/kernel/src/task/udp_echo.rs index f06a53f..459702a 100644 --- a/kernel/src/task/udp_echo.rs +++ b/kernel/src/task/udp_echo.rs @@ -1,28 +1,40 @@ use crate::{ - print, println, network::socket::{SocketType, Socket}, + print, println, network::{socket::{SocketType, Socket}, raw_socket::NetworkErrors}, }; use alloc::string::String; pub async fn udp_echo_server() { - let socket_or_err = Socket::open(SocketType::UDP, 5554).await; + let socket_or_err = Socket::open(SocketType::UDP, 5554, 15).await; match socket_or_err { Ok(mut socket) => { // Listen for a single connection socket.listen().await; loop { let data_or_err = socket.read(0).await; - if let Ok(data) = data_or_err { + if let Ok(mut data) = data_or_err { let user_message = String::from_utf8(data.clone()); match user_message { - Ok(message) => print!("[USER] {}", message), + Ok(message) => { + print!("[USER] {}", message); + if message.to_lowercase().trim() == "quit" || message.to_lowercase().trim() == "exit" { + println!("Closing socket"); + let _ = socket.write(&mut ("Closing socket...".as_bytes().to_vec())); + socket.close(); + return; + } + }, Err(err) => println!("[USER-ERR] {:?}", err), } - let res_or_err = socket.write(data); + let res_or_err = socket.write(&mut data); if let Err(err) = res_or_err { println!("[ERR] {:?}", err); break; } } else if let Err(err) = data_or_err { + if err == NetworkErrors::Timeout { + println!("[INFO] Socket had a timeout"); + break; + } println!("[ERR] {:?}", err); break; } From 04f8683773c3d8a8417c837a49d2f023c9d43cfb Mon Sep 17 00:00:00 2001 From: EthanLavi <89608076+EthanLavi@users.noreply.github.com> Date: Sun, 12 Nov 2023 15:28:04 -0500 Subject: [PATCH 12/36] TCP handshaking --- kernel/src/network/constants.rs | 6 +- kernel/src/network/init.rs | 4 +- kernel/src/network/processing.rs | 112 +++++++++++++++-- kernel/src/network/raw_socket.rs | 41 ++++-- kernel/src/network/rtl8139.rs | 17 ++- kernel/src/network/socket.rs | 161 +++++++++++++++++------- kernel/src/network/tcp.rs | 202 ++++++++++++++++++++++++++++++ kernel/src/network/tcp_session.rs | 143 +++++++++++++++++++++ kernel/src/task/mod.rs | 2 + kernel/src/task/tcp_echo.rs | 47 +++++++ src/main.rs | 3 +- 11 files changed, 662 insertions(+), 76 deletions(-) create mode 100644 kernel/src/network/tcp_session.rs create mode 100644 kernel/src/task/tcp_echo.rs diff --git a/kernel/src/network/constants.rs b/kernel/src/network/constants.rs index 584f21f..fa956b9 100644 --- a/kernel/src/network/constants.rs +++ b/kernel/src/network/constants.rs @@ -3,9 +3,9 @@ pub const BROADCAST_ADDR: u32 = 0xFFFFFFFF; pub const BROADCAST_MAC: u64 = 0xFFFFFFFFFFFF; // Common port numbers -pub const DHCP_CLIENT_PORT: u32 = 68; -pub const DHCP_SERVER_PORT: u32 = 67; -pub const ARP_PORT: u32 = u16::MAX as u32 + 2; // we are using ports above u16 to allow for extended our open "ports" map +pub const DHCP_CLIENT_PORT: u64 = 68; +pub const DHCP_SERVER_PORT: u64 = 67; +pub const ARP_PORT: u64 = u16::MAX as u64 + 2; // we are using ports above u16 to allow for extended our open "ports" map // RTL device specific info pub const RTL_VEND: u32 = 0x10EC; // Vendor ID for the rtl8139 diff --git a/kernel/src/network/init.rs b/kernel/src/network/init.rs index b6cb7d6..f2b7a6f 100644 --- a/kernel/src/network/init.rs +++ b/kernel/src/network/init.rs @@ -1,5 +1,4 @@ use futures_util::StreamExt; -use x86_64::instructions::hlt; use super::constants::{BROADCAST_ADDR, BROADCAST_MAC, DHCP_CLIENT_PORT}; use super::processing; @@ -20,7 +19,7 @@ pub fn init() { pub async fn init_dhcp(wait_timeout: u8) -> bool { // open a socket (shouldn't fail since we are in the init process) - let mut socket = RawSocket::new(DHCP_CLIENT_PORT).unwrap(); + let mut socket = RawSocket::new(DHCP_CLIENT_PORT, 1).unwrap(); disable_network_interrupts(); let rtl_dev_guard = NET_INFO.lock(); @@ -55,7 +54,6 @@ pub async fn init_dhcp(wait_timeout: u8) -> bool { pkt_data = dhcp_res; break; } else { - hlt(); disable_network_interrupts(); { let rtl_dev_guard = NET_INFO.lock(); diff --git a/kernel/src/network/processing.rs b/kernel/src/network/processing.rs index 1b121a2..24c1626 100644 --- a/kernel/src/network/processing.rs +++ b/kernel/src/network/processing.rs @@ -1,9 +1,23 @@ -use alloc::{vec::Vec, collections::VecDeque}; +use alloc::{collections::VecDeque, vec::Vec}; use conquer_once::spin::OnceCell; use crossbeam_queue::ArrayQueue; use futures_util::{task::AtomicWaker, Stream, StreamExt}; -use crate::{println, network::{raw_socket::wake_sockets, layer::{PacketData, Layer}, arp_table::ArpEntry, constants::{BROADCAST_ADDR, ARP_PORT}, arp::ArpPacket, ethernet::{EthernetPacket, EthType}, rtl8139::{NET_INFO, disable_network_interrupts, enable_network_interrupts}}}; +use crate::{ + network::{ + arp::ArpPacket, + arp_table::ArpEntry, + constants::{ARP_PORT, BROADCAST_ADDR, TCP_SYN}, + ethernet::{EthType, EthernetPacket}, + ip::{IPPacket, Protocol}, + layer::{Layer, PacketData, LayerType}, + raw_socket::wake_sockets, + rtl8139::{disable_network_interrupts, enable_network_interrupts, NET_INFO}, + tcp::TCPPacket, + tcp_session::TCPSession, + }, + println, +}; use core::{ pin::Pin, @@ -36,9 +50,7 @@ impl Stream for PendingProcessingStream { let data = queue.pop(); PROCESS_VEC_WAKER.register(cx.waker()); match data { - Some(pkt_data) => { - Poll::Ready(Some(pkt_data)) - } + Some(pkt_data) => Poll::Ready(Some(pkt_data)), None => Poll::Pending, } } @@ -60,6 +72,11 @@ pub async fn init_packet_processing() { let mut raw_packets = PendingProcessingStream::new(); while let Some(pkt_data) = raw_packets.next().await { let amount_parsed_and_pkt = full_parse(pkt_data.as_slice()); + if amount_parsed_and_pkt.1.get_type() == LayerType::UNDEF + || amount_parsed_and_pkt.1.get_type() == LayerType::ICMP { + // Don't deal with unrecognized packets (will fail assert) + continue; + } assert!(amount_parsed_and_pkt.0 == pkt_data.len() || pkt_data.len() < 64); // Try to get the device info disable_network_interrupts(); @@ -106,7 +123,7 @@ pub async fn init_packet_processing() { } } PacketData::DHCP(dhcp) => { - let dst_port = dhcp.udp_packet.dest_port.val() as u32; + let dst_port = dhcp.udp.dest_port.val() as u64; println!("[HANDLER] Found DHCP packet"); if rtl_dev_info.open_ports.contains(&dst_port) { println!("[HANDLER] Port {} is open", dst_port); @@ -123,7 +140,7 @@ pub async fn init_packet_processing() { } } PacketData::UDP(udp) => { - let dst_port = udp.dest_port.val() as u32; + let dst_port = udp.dest_port.val() as u64; if rtl_dev_info.open_ports.contains(&dst_port) { // if we are listening on the port, try to insert it into the map if !rtl_dev_info.ports.contains_key(&dst_port) { @@ -137,6 +154,87 @@ pub async fn init_packet_processing() { wake_sockets(dst_port); } } + PacketData::TCP(tcp) => { + println!("[INFO] Got tcp packet"); + let dst_port = tcp.dest_port.val() as u64; + if rtl_dev_info.open_ports.contains(&dst_port) { + // if we are listening on the port, try to insert it into the map + if !rtl_dev_info.ports.contains_key(&dst_port) { + rtl_dev_info.ports.insert(dst_port, VecDeque::new()); + } + let session_key = TCPSession::gen_session_key( + tcp.ip.source_ip.val(), + tcp.src_port.val(), + tcp.dest_port.val(), + ); + // Open up a session + if !rtl_dev_info.tcp_sessions.contains_key(&session_key) { + if (tcp.get_flags() & TCP_SYN) == 0 { + // Ignore requests when there is no request for syncing + enable_network_interrupts(); + return; + } + // Compact the first packet we receive as the session creation + let eth_layer = EthernetPacket::gen( + tcp.ip.eth.src_mac.val(), + tcp.ip.eth.dest_mac.val(), + EthType::IPv4, + ); + let ip_layer = IPPacket::gen( + eth_layer, + 0, // leaving size undefined for the template + Protocol::TCP, + tcp.ip.destination_ip.val(), + tcp.ip.source_ip.val(), + ); + let tcp_layer = + TCPPacket::gen(ip_layer, tcp.dest_port.val(), tcp.src_port.val()); + let session = TCPSession::new( + tcp_layer, + tcp.ip.source_ip.val(), + tcp.src_port.val(), + tcp.dest_port.val(), + ); + // Lets push to the port -- we are listening then we need to create a new session + rtl_dev_info + .ports + .get_mut(&dst_port) + .unwrap() + .push_back(PacketData::TCP(tcp.clone())); + wake_sockets(dst_port); + rtl_dev_info + .tcp_sessions + .insert(session.session_key(), session); + } + // todo: SYN COOKIES + // todo: Should I have a buffer limit? -- + // ? I think maybe no, because upstream data should be prioritized -- + let tcp_session = rtl_dev_info.tcp_sessions.get_mut(&session_key).unwrap(); + // Generate an acknowledgement and determine if the tcp packet has data + let ack_pkt = tcp_session.gen_acknowledgement(&tcp); + if let Some(response) = ack_pkt.0 { + // If we got a response packet to send back + // todo: what happens if our response is dropped... we need to re-ack? + // generally if no ack is receeived, the host will send another transmission + // ALSO we don't know if our ack was received or not, so we just wait for another transmission + // we could also get duplicate data tho? so we need to identify this case + rtl_dev_info.send_packet(&response.serialize()); + } + println!("[INFO] Packet has data {}", ack_pkt.1); + if ack_pkt.1 { + // Push the data to the application socket + if rtl_dev_info.open_ports.contains(&session_key) { + // if we are listening on the session, try to insert it into the map + if !rtl_dev_info.ports.contains_key(&session_key) { + rtl_dev_info.ports.insert(dst_port, VecDeque::new()); + } + // Insert the data and wake the socket + rtl_dev_info.ports.get_mut(&session_key).unwrap().push_back(PacketData::TCP(tcp)); + wake_sockets(session_key); + } + } + } + } _ => {} // ignore others } enable_network_interrupts(); diff --git a/kernel/src/network/raw_socket.rs b/kernel/src/network/raw_socket.rs index 8a705ce..5b021d5 100644 --- a/kernel/src/network/raw_socket.rs +++ b/kernel/src/network/raw_socket.rs @@ -1,6 +1,8 @@ use futures_util::{Stream, task::AtomicWaker}; use hashbrown::HashMap; +use crate::{task::timeout::{register_timeout, cancel_timeout, TimeoutID}, println}; + use super::{ layer::PacketData, rtl8139::{disable_network_interrupts, enable_network_interrupts, NET_INFO}, @@ -9,25 +11,28 @@ use core::task::{Context, Poll}; use lazy_static::lazy_static; lazy_static! { - pub static ref NEW_PACKET_WAKER: spin::Mutex> = spin::Mutex::new(HashMap::new()); + pub static ref NEW_PACKET_WAKER: spin::Mutex> = spin::Mutex::new(HashMap::new()); } -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq)] pub enum NetworkErrors { PortInUse, NoAvailablePort, NonexistentHost, SocketInServerMode, + FeatureNotAvailableYet, + Timeout, } -// todo: Implement a socket for user-space... - pub struct RawSocket { - port: u32, + port: u64, + timeout_in_epochs: u16, + timeout_active: bool, + timeout_id: TimeoutID, } impl RawSocket { - pub fn new(port: u32) -> Result { + pub fn new(port: u64, timeout_in_epochs: u16) -> Result { disable_network_interrupts(); let mut rtl_dev_info_guard = NET_INFO.lock(); let rtl_dev_info = rtl_dev_info_guard.get_mut().unwrap(); @@ -41,7 +46,7 @@ impl RawSocket { // and allocate a waker NEW_PACKET_WAKER.lock().insert(port, AtomicWaker::new()); enable_network_interrupts(); - Ok(RawSocket { port }) + Ok(RawSocket { port, timeout_in_epochs, timeout_active: false, timeout_id: TimeoutID::new(), }) } fn try_get_packet_inner(&self) -> Option { @@ -66,6 +71,12 @@ impl RawSocket { let vec = rtl_dev_info.ports.get_mut(&self.port); vec.unwrap().clear(); } + // remove tcp session information + // todo: what about udp? what about sending a FIN + /*let session_key = TCPSession::gen_session_key(self.dest_ip, self.dest_port); + if rtl_dev_info.tcp_sessions.contains_key(&session_key) { + rtl_dev_info.tcp_sessions.remove(&session_key); + }*/ enable_network_interrupts(); } } @@ -73,23 +84,31 @@ impl RawSocket { impl Stream for RawSocket { type Item = PacketData; - // todo: What happens if we have packet loss? Won't we deadlock because we will be pending and never get a notification? - fn poll_next(self: core::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll_next(mut self: core::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { disable_network_interrupts(); let pkt = self.try_get_packet_inner(); enable_network_interrupts(); // listen on the port with the waker - NEW_PACKET_WAKER.lock()[&self.port].register(cx.waker()); + let locked_waker_map = NEW_PACKET_WAKER.lock(); + locked_waker_map[&self.port].register(cx.waker()); if pkt.is_some() { + cancel_timeout(self.timeout_id); + self.timeout_active = false; Poll::Ready(pkt) + } else if self.timeout_active { + self.timeout_active = false; + Poll::Ready(None) } else { + // Register a timeout to ensure wakeup at some point + self.timeout_active = true; + self.timeout_id = register_timeout(self.timeout_in_epochs, cx.waker().clone()); Poll::Pending } } } /// Wake sockets by port -pub(crate) fn wake_sockets(port: u32) { +pub(crate) fn wake_sockets(port: u64) { // wake the port up NEW_PACKET_WAKER.lock()[&port].wake(); } \ No newline at end of file diff --git a/kernel/src/network/rtl8139.rs b/kernel/src/network/rtl8139.rs index 4752d84..e380d76 100644 --- a/kernel/src/network/rtl8139.rs +++ b/kernel/src/network/rtl8139.rs @@ -1,3 +1,5 @@ +use core::cmp::max; + use alloc::vec::Vec; use alloc::{collections::VecDeque, vec}; use hashbrown::{HashMap, HashSet}; @@ -23,6 +25,7 @@ use super::constants::{ INTERRUPT_MASK, ROK, RTL_DEV, RTL_VEND, TOK, TRANSMIT_CMD, TRANSMIT_REG, }; use super::processing::add_pkt_data; +use super::tcp_session::TCPSession; use super::{ arp_table::ArpEntry, devices::{Device, PCIClassCodes}, @@ -191,8 +194,9 @@ pub struct RTL8139 { pub my_ip_address: Option, pub dhcp_server_ip: Option, pub mac_address: Option, - pub open_ports: HashSet, - pub ports: HashMap>, + pub open_ports: HashSet, + pub tcp_sessions: HashMap, + pub ports: HashMap>, pub arp_table: Vec, } @@ -268,6 +272,7 @@ impl RTL8139 { open_ports: HashSet::with_capacity(10), ports: HashMap::with_capacity(10), arp_table: Vec::new(), + tcp_sessions: HashMap::with_capacity(10), }); } None @@ -388,6 +393,12 @@ impl RTL8139 { for (i, item) in packet_data.iter().enumerate() { unsafe { *(virtual_buffer_ptr.wrapping_add(i)) = *item }; } + if packet_data.len() < 60 { + for j in 0..(60 - packet_data.len()) { + // pad up to 60 with 0s + unsafe { *(virtual_buffer_ptr.wrapping_add(packet_data.len() + j)) = 0 }; + } + } // TODO: Make this part of self... let reg = TRANSMIT_REG[unsafe { TRANSMIT_IDX as usize }]; @@ -397,7 +408,7 @@ impl RTL8139 { let mut cmd_port = Port::::new((io_base.unwrap() + cmd) as u16); unsafe { reg_port.write(virtual_buffer.as_u64() as u32); - cmd_port.write(packet_data.len() as u32); + cmd_port.write(max(packet_data.len(), 60) as u32); } // Send the packet from the buffer unsafe { diff --git a/kernel/src/network/socket.rs b/kernel/src/network/socket.rs index d9d9fbd..7eb427f 100644 --- a/kernel/src/network/socket.rs +++ b/kernel/src/network/socket.rs @@ -1,19 +1,20 @@ +use core::cmp::max; + use alloc::vec; use alloc::vec::Vec; use futures_util::StreamExt; -use x86_64::instructions::hlt; -use crate::network::layer::LayerType; +use crate::{network::layer::LayerType, println}; use super::{ arp::ArpPacket, - constants::{ARP_PORT, BROADCAST_MAC}, + constants::{ARP_PORT, BROADCAST_MAC, TCP_PSH}, ethernet::{self, EthType, EthernetPacket}, ip::{IPPacket, Protocol}, layer::{HasChecksum, Layer, PacketData}, raw_socket::{NetworkErrors, RawSocket}, rtl8139::{disable_network_interrupts, enable_network_interrupts, NET_INFO}, - udp::UDPPacket, + udp::UDPPacket, tcp_session::TCPSession, }; pub struct NetworkQuery {} @@ -41,7 +42,7 @@ impl NetworkQuery { enable_network_interrupts(); // wait for response - let mut socket = RawSocket::new(ARP_PORT).unwrap(); + let mut socket = RawSocket::new(ARP_PORT, 3).unwrap(); let mut timeout = 0; loop { if let Some(pkt) = socket.next().await { @@ -51,17 +52,16 @@ impl NetworkQuery { let arp_pkt = pkt.unwrap_arp(); return Some(arp_pkt.sender_mac.val()); } else { - hlt(); disable_network_interrupts(); { - let rtl_dev_guard = NET_INFO.lock(); - let rtl_dev_info = rtl_dev_guard.get_ref().unwrap(); - rtl_dev_info.send_packet(&arp_layer.serialize()); + let rtl_dev_guard = NET_INFO.lock(); + let rtl_dev_info = rtl_dev_guard.get_ref().unwrap(); + rtl_dev_info.send_packet(&arp_layer.serialize()); } enable_network_interrupts(); } timeout += 1; - if timeout == wait_timeout * 18 { + if timeout == wait_timeout * 6 { socket.close(); return None; } @@ -86,16 +86,17 @@ pub struct Socket { socket_state: SocketState, raw_socket: RawSocket, dest_port: u16, - dest_address: u32, + dest_ip: u32, dest_mac: u64, pub src_port: u16, src_address: u32, src_mac: u64, + wait_timeout: u16, } impl Socket { // Can't send yet - pub async fn open(socket_type: SocketType, src_port: u16) -> Result { + pub async fn open(socket_type: SocketType, src_port: u16, wait_timeout: u16) -> Result { let mut chosen_src_port = src_port; disable_network_interrupts(); let rtl_dev_info_locked = NET_INFO.lock(); @@ -106,7 +107,7 @@ impl Socket { let open_ports = &rtl_dev_info.open_ports; // Linear probe to see if any ports are open for i in 1000..u16::MAX { - if !open_ports.contains(&(i as u32)) { + if !open_ports.contains(&(i as u64)) { chosen_src_port = i; break; } @@ -117,45 +118,67 @@ impl Socket { if chosen_src_port == 0 { return Err(NetworkErrors::NoAvailablePort); } - let raw_socket = RawSocket::new(chosen_src_port as u32); + let raw_socket = RawSocket::new(chosen_src_port as u64, max(wait_timeout * 18, 1)); match raw_socket { Ok(socket) => Ok(Socket { socket_type, raw_socket: socket, socket_state: SocketState::Listening, dest_port: 0, - dest_address: 0, + dest_ip: 0, dest_mac: 0, src_port, src_address, src_mac, + wait_timeout, }), Err(err) => Err(err), } } // Will listen for new connections and create new sessions - // UDP can only listen for one connection (and thus will return none). + // UDP can only listen for one connection (and thus will return none). -- UDP is connectionless pub async fn listen(&mut self) -> Option { + if self.socket_state != SocketState::Listening { + return None; // todo: this is failing silently + } loop { if let Some(pkt) = self.raw_socket.next().await { - if pkt.get_type() != LayerType::UDP { - continue; - } - let udp_pkt = pkt.unwrap_udp(); - self.dest_port = udp_pkt.src_port.val(); - self.dest_mac = udp_pkt.ip_packet.ethernet_packet.src_mac.val(); - self.dest_address = udp_pkt.ip_packet.source_ip.val(); - self.socket_state = SocketState::Ready; - disable_network_interrupts(); - let mut rtl_dev_info_guard = NET_INFO.lock(); - let rtl_dev_info = rtl_dev_info_guard.get_mut().unwrap(); - // re-enqueue the packet - if let Some(vec) = rtl_dev_info.ports.get_mut(&(self.src_port as u32)) { - vec.push_front(PacketData::UDP(udp_pkt)); + if pkt.get_type() == LayerType::UDP && self.socket_type == SocketType::UDP { + let udp_pkt = pkt.unwrap_udp(); + self.dest_port = udp_pkt.src_port.val(); + self.dest_mac = udp_pkt.ip.eth.src_mac.val(); + self.dest_ip = udp_pkt.ip.source_ip.val(); + self.socket_state = SocketState::Ready; + disable_network_interrupts(); + let mut rtl_dev_info_guard = NET_INFO.lock(); + let rtl_dev_info = rtl_dev_info_guard.get_mut().unwrap(); + // re-enqueue the packet + if let Some(vec) = rtl_dev_info.ports.get_mut(&(self.src_port as u64)) { + vec.push_front(PacketData::UDP(udp_pkt)); + } + enable_network_interrupts(); + return None; + } else if pkt.get_type() == LayerType::TCP && self.socket_type == SocketType::TCP { + let tcp_pkt = pkt.unwrap_tcp(); + let dest_address = tcp_pkt.ip.source_ip.val(); + let dest_port = tcp_pkt.src_port.val(); + let session_key = TCPSession::gen_session_key(dest_address, dest_port, self.src_port); + let raw_socket = RawSocket::new(session_key, max(self.wait_timeout * 18, 1)).unwrap(); + println!("[INFO] Spawned new TCP session"); + return Some(Socket { + socket_type: SocketType::TCP, + raw_socket, + socket_state: SocketState::Ready, + dest_port, + dest_ip: dest_address, + dest_mac: tcp_pkt.ip.eth.src_mac.val(), + src_port: self.src_port, + src_address: self.src_address, + src_mac: self.src_mac, + wait_timeout: self.wait_timeout, + }); } - enable_network_interrupts(); - return None; } } } @@ -165,6 +188,7 @@ impl Socket { dest_address: u32, dest_port: u16, src_port: u16, + wait_timeout: u16, ) -> Result { let mut chosen_src_port = src_port; disable_network_interrupts(); @@ -176,7 +200,7 @@ impl Socket { let open_ports = &rtl_dev_info.open_ports; // Linear probe to see if any ports are open for i in 1000..u16::MAX { - if !open_ports.contains(&(i as u32)) { + if !open_ports.contains(&(i as u64)) { chosen_src_port = i; break; } @@ -191,18 +215,19 @@ impl Socket { if chosen_src_port == 0 { return Err(NetworkErrors::NoAvailablePort); } - let raw_socket = RawSocket::new(chosen_src_port as u32); + let raw_socket = RawSocket::new(chosen_src_port as u64, max(wait_timeout * 18, 1)); match raw_socket { Ok(socket) => Ok(Socket { socket_type, socket_state: SocketState::Ready, raw_socket: socket, dest_port, - dest_address, + dest_ip: dest_address, dest_mac: dest_mac.unwrap(), src_port, src_address, src_mac, + wait_timeout, }), Err(err) => Err(err), } @@ -212,22 +237,62 @@ impl Socket { self.raw_socket.close(); } - pub async fn read(&mut self, size: u32) -> Result, NetworkErrors> { - assert!(self.socket_type == SocketType::UDP); + pub async fn read(&mut self, size: usize) -> Result, NetworkErrors> { if self.socket_state == SocketState::Listening { return Err(NetworkErrors::SocketInServerMode); } - if let Some(pkt) = self.raw_socket.next().await { - if pkt.get_type() != LayerType::UDP { - return Ok(vec![]); + match self.socket_type { + SocketType::UDP => self.read_udp(size).await, + SocketType::TCP => self.read_tcp(size).await, + } + } + + async fn read_udp(&mut self, size: usize) -> Result, NetworkErrors> { + loop { + if let Some(pkt) = self.raw_socket.next().await { + if pkt.get_type() == LayerType::UDP { + let udp_pkt = pkt.unwrap_udp(); + return Ok(udp_pkt.data); + } + } else { + return Err(NetworkErrors::Timeout); } - let udp_pkt = pkt.unwrap_udp(); - return Ok(udp_pkt.data); } - Ok(vec![]) } - pub fn write(&self, data: Vec) -> Result<(), NetworkErrors> { + async fn read_tcp(&mut self, size: usize) -> Result, NetworkErrors> { + loop { + let mut res_vec = vec![]; + if let Some(pkt) = self.raw_socket.next().await { + if pkt.get_type() == LayerType::TCP { + let mut tcp_pkt = pkt.unwrap_tcp(); + res_vec.append(&mut tcp_pkt.data); + if tcp_pkt.get_flags() & TCP_PSH != 0 { + return Ok(res_vec); + } + } + if size <= res_vec.len() { + return Ok(res_vec); + } + } else { + return Err(NetworkErrors::Timeout); + } + } + } + + pub fn write(&self, data: &mut Vec) -> Result { + match self.socket_type { + SocketType::UDP => self.write_udp(data), + SocketType::TCP => self.write_tcp(data), + } + } + + fn write_tcp(&self, data: &mut Vec) -> Result { + let _ = data; + Err(NetworkErrors::FeatureNotAvailableYet) + } + + fn write_udp(&self, data: &mut Vec) -> Result { assert!(self.socket_type == SocketType::UDP); if self.socket_state == SocketState::Listening { return Err(NetworkErrors::SocketInServerMode); @@ -239,17 +304,17 @@ impl Socket { udp_size, Protocol::UDP, self.src_address, - self.dest_address, + self.dest_ip, ); let data_len = data.len(); let mut udp_layer = UDPPacket::gen(ip_layer, self.src_port, self.dest_port, data_len as u16); - udp_layer.data = data; + udp_layer.data = data.to_vec(); // todo: split the data and return the amount actually written! let data_2_send = udp_layer.serialize(); let start_udp = data_2_send.len() - (UDPPacket::packet_size() as usize + data_len); let start_ip = start_udp - (IPPacket::packet_size() as usize); udp_layer - .ip_packet + .ip .calculate_checksum(&data_2_send[start_ip..start_udp]); udp_layer.calculate_checksum(&data_2_send[start_udp..]); let data_2_send_final = udp_layer.serialize(); @@ -260,6 +325,6 @@ impl Socket { .unwrap() .send_packet(&data_2_send_final); enable_network_interrupts(); - Ok(()) + Ok(data_len as u16) } } diff --git a/kernel/src/network/tcp.rs b/kernel/src/network/tcp.rs index e69de29..5ade514 100644 --- a/kernel/src/network/tcp.rs +++ b/kernel/src/network/tcp.rs @@ -0,0 +1,202 @@ +use alloc::vec; +use alloc::vec::Vec; + +use crate::println; + +use super::{ + bytefield::{Bytefield16, Bytefield32}, + ip::IPPacket, + layer::{HasChecksum, Layer, LayerType}, +}; + +#[derive(Debug, Clone)] +pub struct TCPPacket { + pub ip: IPPacket, + pub src_port: Bytefield16, // 2 bytes + pub dest_port: Bytefield16, // 2 bytes + pub seq_num: Bytefield32, // 4 bytes + pub ack_num: Bytefield32, // 4 bytes + pub flags: Bytefield16, // 2 bytes + pub sliding_window: Bytefield16, // 2 bytes + pub checksum: Bytefield16, // 2 bytes + pub urgent: Bytefield16, // +2 more, 20 bytes up until here + pub options: Vec, + pub data: Vec, +} + +impl TCPPacket { + pub fn new() -> Self { + TCPPacket { + ip: IPPacket::new(), + src_port: Bytefield16::new(0), + dest_port: Bytefield16::new(0), + seq_num: Bytefield32::new(0), + ack_num: Bytefield32::new(0), + flags: Bytefield16::new(0), + sliding_window: Bytefield16::new(0), + checksum: Bytefield16::new(0), + urgent: Bytefield16::new(0), + options: vec![], + data: vec![], + } + } + + pub fn gen(ip_packet: IPPacket, src_port: u16, dest_port: u16) -> Self { + TCPPacket { + ip: ip_packet, + src_port: Bytefield16::new(src_port), + dest_port: Bytefield16::new(dest_port), + seq_num: Bytefield32::new(0), + ack_num: Bytefield32::new(0), + flags: Bytefield16::new(6 << 12), // a value of 6 in the header_offset (6*4 = 24 bits -> b/c no options) + sliding_window: Bytefield16::new(u16::MAX), // sliding window :(, pain to implement... we can just allow unlimited data but thats insecure... Also we would like to keep track of how much we are allowed to send + checksum: Bytefield16::new(0), + urgent: Bytefield16::new(0), + options: vec![0x02, 0x04, 0x05, 0xb4], // need to get to 24 + // todo: we set the mss here manually -- fix + data: vec![], + } + } + + pub fn options_size() -> u16 { + 4 + } + + // N.B. Getting operations DONT swap endianness because we should be in host order (and after parsing we are) + // Get the header offset + pub fn get_header_offset(&self) -> u8 { + // convert header offset into bytes + ((self.flags.val() >> 12) as u8) * 4 + } + + // Get the flags + pub fn get_flags(&self) -> u8 { + (self.flags.val() & 0x00FF) as u8 + } + + // N.B. Setting operations swap endianness because when we use val, we are in network-order + // Turn off the flags provided + pub fn turn_off_flags(&mut self, flags: u8) { + let new_flags = self.flags.swapped_endianness().val() & !(flags as u16); + self.flags = Bytefield16::new(new_flags); + } + + // Turn on the flags provided + pub fn turn_on_flags(&mut self, flags: u8) { + let new_flags = self.flags.swapped_endianness().val() | (flags as u16); + self.flags = Bytefield16::new(new_flags); + } + + // Get total length of the tcp portion + pub fn total_size(&self) -> u16 { + (20 + self.options.len() + self.data.len()) as u16 + } +} + +impl Layer for TCPPacket { + type Input = IPPacket; + fn parse(ip_layer: IPPacket, bytevec: &[u8]) -> (Self, usize, LayerType) + where + Self: Sized, + { + let mut packet = TCPPacket::new(); // create an empty packet + // Read 14 bytes + let mut i = 0; + packet.ip = ip_layer; + packet.src_port = Bytefield16::read_inc(&bytevec[i..], &mut i); + packet.dest_port = Bytefield16::read_inc(&bytevec[i..], &mut i); + packet.seq_num = Bytefield32::read_inc(&bytevec[i..], &mut i); + packet.ack_num = Bytefield32::read_inc(&bytevec[i..], &mut i); + packet.flags = Bytefield16::read_inc(&bytevec[i..], &mut i); + packet.sliding_window = Bytefield16::read_inc(&bytevec[i..], &mut i); + packet.checksum = Bytefield16::read_inc(&bytevec[i..], &mut i); + packet.urgent = Bytefield16::read_inc(&bytevec[i..], &mut i); + // read remaining bytes and place them into the data buffer + println!("[INFO] Packet header offset {}", packet.get_header_offset()); + println!("[INFO] Buffer is {}", bytevec.len()); + for _ in 0..(packet.get_header_offset() - 20) { + packet.options.push(bytevec[i]); + i += 1; + } + assert!(i == packet.get_header_offset().into()); // Valid i is header_length + let data_size = packet.ip.total_length.val() - packet.get_header_offset() as u16 - IPPacket::packet_size(); + for _ in 0..data_size { + packet.options.push(bytevec[i]); + i += 1; + } + (packet, i, LayerType::UNDEF) + } + + fn serialize(&self) -> alloc::vec::Vec { + let mut res = vec![]; + res.extend(self.ip.serialize()); + res.extend(self.src_port.data); + res.extend(self.dest_port.data); + res.extend(self.seq_num.data); + res.extend(self.ack_num.data); + res.extend(self.flags.data); + res.extend(self.sliding_window.data); + res.extend(self.checksum.data); + res.extend(self.urgent.data); + res.extend(&self.options); + res.extend(&self.data); + assert!( + res.len() == (20 + self.ip.serialize().len() + self.options.len() + self.data.len()) + ); + res + } + + fn packet_size() -> u16 { + 20 + } +} + +impl HasChecksum for TCPPacket { + fn calculate_checksum(&mut self, data: &[u8]) { + // Starting vars + let mut sum: u32 = 0; + // calculating checksum on serialized bytefield (so its network byte order and must be swapped) + let mut tcp_len = self.total_size(); + + // First we do the IP as a pseduo header + let ip = &self.ip; + sum += (ip.source_ip.data[0] as u32) | (ip.source_ip.data[1] as u32) << 8; + sum += (ip.source_ip.data[2] as u32) | (ip.source_ip.data[3] as u32) << 8; + sum += (ip.destination_ip.data[0] as u32) | (ip.destination_ip.data[1] as u32) << 8; + sum += (ip.destination_ip.data[2] as u32) | (ip.destination_ip.data[3] as u32) << 8; + + // Sum protocol and length + let protocol = Bytefield16::new(ip.protocol as u16); + sum += (protocol.data[0] as u32) | (protocol.data[1] as u32) << 8; + let segment_size = Bytefield16::new(TCPPacket::packet_size() + TCPPacket::options_size()); + sum += (segment_size.data[0] as u32) | (segment_size.data[1] as u32) << 8; + + // Zero the checksum field + self.checksum = Bytefield16::new(0); + + // Sum the body + let mut ptr = 0; + while tcp_len > 1 { + sum += (data[ptr] as u32) | ((data[ptr + 1] as u32) << 8); + tcp_len -= 2; + ptr += 2; + } + + if data.len() % 2 == 1 { + // Add the padding if the packet length is odd + sum += data[ptr] as u32; + } + + // Add the carries + while sum > 0xFFFF { + sum = (sum & 0xFFFF) + (sum >> 16); + } + + // One's complement and swap the bytes because we did our sum in big endian + // (and the bytefield will try to convert to big endian) + let res = u16::swap_bytes(!sum as u16); + + // Set the checksum + self.checksum = Bytefield16::new(res); + } +} diff --git a/kernel/src/network/tcp_session.rs b/kernel/src/network/tcp_session.rs new file mode 100644 index 0000000..a691bd6 --- /dev/null +++ b/kernel/src/network/tcp_session.rs @@ -0,0 +1,143 @@ +use super::{ + bytefield::{Bytefield16, Bytefield32}, + constants::{TCP_ACK, TCP_FIN, TCP_RST, TCP_SYN}, + ethernet::EthernetPacket, + ip::IPPacket, + layer::{HasChecksum, Layer}, + tcp::TCPPacket, +}; + +#[derive(Debug, PartialEq, Eq)] +enum TCPSessionState { + Waiting, + Syncing, + Established, + Closing, + Closed, +} + +pub struct TCPSession { + session_template: TCPPacket, + pub dest_ip: u32, + pub dest_port: u16, + pub src_port: u16, + pub send_acked_up_to: u32, + pub recv_acked_up_to: u32, + pub window_size: u16, + session_state: TCPSessionState, +} + +impl TCPSession { + pub fn new(session_template: TCPPacket, dest_ip: u32, dest_port: u16, src_port: u16) -> Self { + TCPSession { + session_template, + send_acked_up_to: 0, + recv_acked_up_to: 0, + dest_ip, + dest_port, + src_port, + window_size: u16::MAX, + session_state: TCPSessionState::Waiting, + } + } + + // todo: replace for random generated number + pub fn gen_starting_seq_num() -> u32 { + 0 + } + + pub fn session_key(&self) -> u64 { + Self::gen_session_key(self.dest_ip, self.dest_port, self.src_port) + } + + pub fn gen_session_key(dest_ip: u32, dest_port: u16, src_port: u16) -> u64 { + (dest_ip as u64) << 32 | (dest_port as u64) << 16 | src_port as u64 + } + + pub fn close(&mut self) { + self.session_state = TCPSessionState::Closing; + // get rtl object + // transition into + } + + pub fn reset(&mut self) { + // get rtl object + // send a rst and immediately stop receiving + } + + pub fn gen_acknowledgement(&mut self, request: &TCPPacket) -> (Option, bool) { + let has_syn_flag = (request.get_flags() & TCP_SYN) != 0; + let has_ack_flag = (request.get_flags() & TCP_ACK) != 0; + let has_fin_flag = (request.get_flags() & TCP_FIN) != 0; + let has_rst_flag = (request.get_flags() & TCP_RST) != 0; + // regularly update window size + self.window_size = request.sliding_window.val(); + let mut response = self.session_template.clone(); + // Setting length to be 24 [sizeof(TCP-header) + sizeof(options)] + 20 [sizeof(IP-header)] + response.ip.total_length = + Bytefield16::new(TCPPacket::packet_size() + TCPPacket::options_size() + IPPacket::packet_size()); + let has_data = !request.data.is_empty(); + if has_rst_flag { + if self.session_state != TCPSessionState::Waiting + && self.session_state != TCPSessionState::Syncing + { + self.session_state = TCPSessionState::Closed; + } + return (None, has_data); + } else { + match self.session_state { + TCPSessionState::Waiting => { + if has_syn_flag && has_ack_flag { + response.turn_on_flags(TCP_ACK); + if request.ack_num.val() != 1 { + // Incorrect ack num + return (None, false); + } + response.ack_num = Bytefield32::new(request.seq_num.val() + 1); + response.seq_num = Bytefield32::new(45503); + } else if has_syn_flag { + response.turn_on_flags(TCP_SYN | TCP_ACK); + response.ack_num = Bytefield32::new(request.seq_num.val() + 1); + response.seq_num = Bytefield32::new(45503); + } else { + // No syn packet but the session hasn't been established + // Therefore we drop the packet + return (None, false); + } + self.session_state = TCPSessionState::Syncing; + } + TCPSessionState::Syncing => { + if has_ack_flag { + if request.ack_num.val() != 1 { + // Incorrect ack num + return (None, false); + } + self.session_state = TCPSessionState::Established; + } else { + // Waiting on the ack packet + // Dropping this packet + return (None, false); + } + } + TCPSessionState::Established => {} + TCPSessionState::Closing => { + if has_fin_flag { + self.session_state = TCPSessionState::Closed; + } + // send fin packet + } + TCPSessionState::Closed => { + // shouldn't a closed tcp session be removed? + // todo: define this behavior + } + }; + } + // Calculate checksums + let data = response.serialize(); + let start_tcp = IPPacket::packet_size() as usize + EthernetPacket::packet_size() as usize; + let start_ip = EthernetPacket::packet_size() as usize; + response.ip.calculate_checksum(&data[start_ip..start_tcp]); + response.calculate_checksum(&data[start_tcp..]); + (Some(response), has_data) + } +} diff --git a/kernel/src/task/mod.rs b/kernel/src/task/mod.rs index 533dbc6..db1eeaf 100644 --- a/kernel/src/task/mod.rs +++ b/kernel/src/task/mod.rs @@ -7,6 +7,8 @@ pub mod executor; pub mod keyboard; pub mod simple_executor; pub mod udp_echo; +pub mod timeout; +pub mod tcp_echo; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] struct TaskId(u64); diff --git a/kernel/src/task/tcp_echo.rs b/kernel/src/task/tcp_echo.rs new file mode 100644 index 0000000..2fb34c2 --- /dev/null +++ b/kernel/src/task/tcp_echo.rs @@ -0,0 +1,47 @@ +use crate::{ + print, println, network::{socket::{SocketType, Socket}, raw_socket::NetworkErrors}, +}; +use alloc::string::String; + +pub async fn tcp_echo_server() { + let socket_or_err = Socket::open(SocketType::TCP, 6664, 5).await; + match socket_or_err { + Ok(mut socket_gen) => { + // Allow 10 sockets + for _ in 0..10 { + let mut socket = socket_gen.listen().await.unwrap(); + loop { + let data_or_err = socket.read(0).await; + if let Ok(mut data) = data_or_err { + let user_message = String::from_utf8(data.clone()); + match user_message { + Ok(message) => { + print!("[USER] {}", message); + if message.to_lowercase().trim() == "quit" || message.to_lowercase().trim() == "exit" { + println!("Closing socket"); + // socket.write(&mut ("Closing socket...".as_bytes().to_vec())); + socket.close(); + break; + } + }, + Err(err) => println!("[USER-ERR] {:?}", err), + } + let res_or_err = socket.write(&mut data); + if let Err(err) = res_or_err { + println!("[ERR] {:?}", err); + break; + } + } else if let Err(err) = data_or_err { + if err == NetworkErrors::Timeout { + println!("[INFO] Socket had a timeout"); + continue; + } + println!("[ERR] {:?}", err); + break; + } + } + } + }, + Err(err) => println!("[ERR] {:?}", err), + } +} \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index 3cc8739..206a7af 100644 --- a/src/main.rs +++ b/src/main.rs @@ -14,8 +14,9 @@ fn main() { cmd.arg("-machine").arg("kernel_irqchip=off"); // Virtualizing a network, has an entry point to inject packets into the isolated network // Can send UDP from localhost:5555 to SOUP:5554 + // Also sends TCP from localhost:6666 to SOUP:6664 cmd.arg("-netdev") - .arg("user,id=net0,hostfwd=udp::5555-:5554"); + .arg("user,id=net0,hostfwd=udp::5555-:5554,hostfwd=tcp::6666-:6664"); // Making sure we have the rtl8139 as a hardware resource cmd.arg("-device") .arg("rtl8139,netdev=net0,mac=00:11:22:33:44:55"); From 09feca57955e67584f2249960645855fd3d0959f Mon Sep 17 00:00:00 2001 From: EthanLavi <89608076+EthanLavi@users.noreply.github.com> Date: Thu, 16 Nov 2023 21:10:43 -0500 Subject: [PATCH 13/36] Refactoring and TCP --- kernel/src/main.rs | 13 +-- kernel/src/network/arp.rs | 2 +- kernel/src/network/dhcp.rs | 2 +- kernel/src/network/ethernet.rs | 4 +- kernel/src/network/ip.rs | 4 +- kernel/src/network/layer.rs | 15 +++- kernel/src/network/netsync.rs | 7 +- kernel/src/network/processing.rs | 30 ++++--- kernel/src/network/raw_socket.rs | 2 +- kernel/src/network/rtl8139.rs | 17 +++- kernel/src/network/socket.rs | 101 +++++++++++++++++---- kernel/src/network/tcp.rs | 20 ++--- kernel/src/network/tcp_session.rs | 142 ++++++++++++++++++++++++------ kernel/src/network/udp.rs | 2 +- kernel/src/task/tcp_echo.rs | 8 +- kernel/src/task/udp_echo.rs | 4 +- 16 files changed, 273 insertions(+), 100 deletions(-) diff --git a/kernel/src/main.rs b/kernel/src/main.rs index 9542a98..b0cd786 100644 --- a/kernel/src/main.rs +++ b/kernel/src/main.rs @@ -45,14 +45,6 @@ pub static BOOTLOADER_CONFIG: BootloaderConfig = { entry_point!(kernel_main, config = &BOOTLOADER_CONFIG); -async fn async_number() -> u32 { - 42 -} - -async fn example_task() { - let number = async_number().await; - println!("async number: {}", number); -} async fn do_init_dhcp() { let status_init_dhcp = init_dhcp(4).await; @@ -100,10 +92,9 @@ fn kernel_main(boot_info: &'static mut BootInfo) -> ! { // Start the processing of pending packets init_process_packet_data(&mut executor); executor.spawn(Task::new(do_init_dhcp())); // not entirely async, will finish before others are run - executor.wait(); - executor.spawn(Task::new(example_task())); + executor.wait(); // todo: fix wait executor.spawn(Task::new(keyboard::print_keypresses())); - executor.spawn(Task::new(udp_echo::udp_echo_server())); + // executor.spawn(Task::new(udp_echo::udp_echo_server())); executor.spawn(Task::new(tcp_echo::tcp_echo_server())); executor.run(); } diff --git a/kernel/src/network/arp.rs b/kernel/src/network/arp.rs index 23cb700..57e2396 100644 --- a/kernel/src/network/arp.rs +++ b/kernel/src/network/arp.rs @@ -74,7 +74,7 @@ impl Layer for ArpPacket { packet.recp_ip = Bytefield32::read_inc(&bytevec[i..], &mut i); assert!(i == 28); // Arp packet should be 28 bytes - (packet, i, LayerType::UNDEF) + (packet, i, LayerType::END) } fn serialize(&self) -> Vec { diff --git a/kernel/src/network/dhcp.rs b/kernel/src/network/dhcp.rs index 453165a..b06f0e8 100644 --- a/kernel/src/network/dhcp.rs +++ b/kernel/src/network/dhcp.rs @@ -133,7 +133,7 @@ impl Layer for DHCPPacket { let left_to_parse = packet.udp.length.val() - 308; i += left_to_parse as usize; assert!(i >= 300); // 300 bytes - (packet, i, LayerType::UNDEF) + (packet, i, LayerType::END) } fn serialize(&self) -> alloc::vec::Vec { diff --git a/kernel/src/network/ethernet.rs b/kernel/src/network/ethernet.rs index 1c80526..32a530c 100644 --- a/kernel/src/network/ethernet.rs +++ b/kernel/src/network/ethernet.rs @@ -70,8 +70,8 @@ impl Layer for EthernetPacket { let layer_type = match &packet.packet_type { EthType::Arp => LayerType::ARP, EthType::IPv4 => LayerType::IP, - EthType::RoCE => LayerType::UNDEF, - EthType::Unknown => LayerType::UNDEF, + EthType::RoCE => LayerType::ERR, + EthType::Unknown => LayerType::ERR, }; (packet, i, layer_type) } diff --git a/kernel/src/network/ip.rs b/kernel/src/network/ip.rs index b1c4629..df63b3d 100644 --- a/kernel/src/network/ip.rs +++ b/kernel/src/network/ip.rs @@ -130,8 +130,8 @@ impl Layer for IPPacket { Protocol::ICMP => LayerType::ICMP, Protocol::TCP => LayerType::TCP, Protocol::UDP => LayerType::UDP, - Protocol::RDP => LayerType::UNDEF, - Protocol::Unsupported => LayerType::UNDEF, + Protocol::RDP => LayerType::ERR, + Protocol::Unsupported => LayerType::ERR, }; (packet, i, layer_type) } diff --git a/kernel/src/network/layer.rs b/kernel/src/network/layer.rs index 90f4998..95e2f7a 100644 --- a/kernel/src/network/layer.rs +++ b/kernel/src/network/layer.rs @@ -36,7 +36,7 @@ impl Layer for EmptyLayer { where Self: Sized, { - (Self {}, 0, LayerType::UNDEF) + (Self {}, 0, LayerType::END) } fn serialize(&self) -> Vec { @@ -62,10 +62,12 @@ pub enum LayerType { ICMP, DHCP, TCP, - UNDEF, // the default layer type + ERR, + END, // the default layer type } /// Wrapper type to allow me to return a generic +/// todo: reduce size of enum #[derive(Debug)] pub enum PacketData { ETH(EthernetPacket), @@ -75,6 +77,7 @@ pub enum PacketData { ICMP(EmptyLayer), DHCP(DHCPPacket), TCP(TCPPacket), + ERR(EmptyLayer), UNDEF(EmptyLayer), } @@ -130,7 +133,8 @@ impl PacketData { PacketData::ICMP(_) => LayerType::ICMP, PacketData::DHCP(_) => LayerType::DHCP, PacketData::TCP(_) => LayerType::TCP, - PacketData::UNDEF(_) => LayerType::UNDEF, + PacketData::ERR(_) => LayerType::ERR, + PacketData::UNDEF(_) => LayerType::END, } } } @@ -192,7 +196,10 @@ pub fn full_parse(packet: &[u8]) -> (usize, PacketData) { i += size; next_type = empty_type; }, - LayerType::UNDEF => { + LayerType::ERR => { + return (0, PacketData::ERR(EmptyLayer::new())); + } + LayerType::END => { return (i, last_layer); } } diff --git a/kernel/src/network/netsync.rs b/kernel/src/network/netsync.rs index c8b6c9c..0f88bfe 100644 --- a/kernel/src/network/netsync.rs +++ b/kernel/src/network/netsync.rs @@ -1,6 +1,6 @@ use spin::MutexGuard; -use super::rtl8139::RTL8139; +use super::rtl8139::{RTL8139, NetworkConfig}; struct InterruptGuard { @@ -31,11 +31,12 @@ impl Drop for NetworkInterruptsGuard<'_> { pub struct SafeRTL8139 { data: spin::Mutex>, + pub config: spin::Mutex, } impl SafeRTL8139 { - pub fn new(data: spin::Mutex>) -> Self { - Self { data } + pub fn new(data: spin::Mutex>, config: spin::Mutex) -> Self { + Self { data, config } } pub fn lock(&self) -> NetworkInterruptsGuard { diff --git a/kernel/src/network/processing.rs b/kernel/src/network/processing.rs index 24c1626..9883e64 100644 --- a/kernel/src/network/processing.rs +++ b/kernel/src/network/processing.rs @@ -10,7 +10,7 @@ use crate::{ constants::{ARP_PORT, BROADCAST_ADDR, TCP_SYN}, ethernet::{EthType, EthernetPacket}, ip::{IPPacket, Protocol}, - layer::{Layer, PacketData, LayerType}, + layer::{Layer, LayerType, PacketData}, raw_socket::wake_sockets, rtl8139::{disable_network_interrupts, enable_network_interrupts, NET_INFO}, tcp::TCPPacket, @@ -72,8 +72,9 @@ pub async fn init_packet_processing() { let mut raw_packets = PendingProcessingStream::new(); while let Some(pkt_data) = raw_packets.next().await { let amount_parsed_and_pkt = full_parse(pkt_data.as_slice()); - if amount_parsed_and_pkt.1.get_type() == LayerType::UNDEF - || amount_parsed_and_pkt.1.get_type() == LayerType::ICMP { + if amount_parsed_and_pkt.1.get_type() == LayerType::ERR + || amount_parsed_and_pkt.1.get_type() == LayerType::ICMP + { // Don't deal with unrecognized packets (will fail assert) continue; } @@ -83,11 +84,11 @@ pub async fn init_packet_processing() { let mut net_dev = NET_INFO.lock(); // Get the device fields let rtl_dev_info = net_dev.get_mut().unwrap(); + let mut rtl_dev_config = NET_INFO.config.lock(); match amount_parsed_and_pkt.1 { PacketData::ARP(arp) => { // todo: also check for broadcast if arp.recp_ip.val() == rtl_dev_info.my_ip_address.unwrap_or(0) { - // println!("[INT-HANDLER] Send a response back"); let eth_layer = EthernetPacket::gen( arp.sender_mac.val(), rtl_dev_info.mac_address.unwrap(), @@ -155,7 +156,6 @@ pub async fn init_packet_processing() { } } PacketData::TCP(tcp) => { - println!("[INFO] Got tcp packet"); let dst_port = tcp.dest_port.val() as u64; if rtl_dev_info.open_ports.contains(&dst_port) { // if we are listening on the port, try to insert it into the map @@ -168,7 +168,7 @@ pub async fn init_packet_processing() { tcp.dest_port.val(), ); // Open up a session - if !rtl_dev_info.tcp_sessions.contains_key(&session_key) { + if !rtl_dev_config.tcp_sessions.contains_key(&session_key) { if (tcp.get_flags() & TCP_SYN) == 0 { // Ignore requests when there is no request for syncing enable_network_interrupts(); @@ -202,16 +202,16 @@ pub async fn init_packet_processing() { .unwrap() .push_back(PacketData::TCP(tcp.clone())); wake_sockets(dst_port); - rtl_dev_info + rtl_dev_config .tcp_sessions .insert(session.session_key(), session); } // todo: SYN COOKIES // todo: Should I have a buffer limit? -- // ? I think maybe no, because upstream data should be prioritized -- - let tcp_session = rtl_dev_info.tcp_sessions.get_mut(&session_key).unwrap(); + let tcp_session = rtl_dev_config.tcp_sessions.get_mut(&session_key).unwrap(); // Generate an acknowledgement and determine if the tcp packet has data - let ack_pkt = tcp_session.gen_acknowledgement(&tcp); + let ack_pkt = tcp_session.process_recv(&tcp); if let Some(response) = ack_pkt.0 { // If we got a response packet to send back // todo: what happens if our response is dropped... we need to re-ack? @@ -220,16 +220,19 @@ pub async fn init_packet_processing() { // we could also get duplicate data tho? so we need to identify this case rtl_dev_info.send_packet(&response.serialize()); } - println!("[INFO] Packet has data {}", ack_pkt.1); if ack_pkt.1 { - // Push the data to the application socket + // Push the packet to the raw socket --> it will handle its data, if present if rtl_dev_info.open_ports.contains(&session_key) { // if we are listening on the session, try to insert it into the map if !rtl_dev_info.ports.contains_key(&session_key) { - rtl_dev_info.ports.insert(dst_port, VecDeque::new()); + rtl_dev_info.ports.insert(session_key, VecDeque::new()); } // Insert the data and wake the socket - rtl_dev_info.ports.get_mut(&session_key).unwrap().push_back(PacketData::TCP(tcp)); + rtl_dev_info + .ports + .get_mut(&session_key) + .unwrap() + .push_back(PacketData::TCP(tcp)); wake_sockets(session_key); } } @@ -237,6 +240,7 @@ pub async fn init_packet_processing() { } _ => {} // ignore others } + drop(net_dev); enable_network_interrupts(); } } diff --git a/kernel/src/network/raw_socket.rs b/kernel/src/network/raw_socket.rs index 5b021d5..54e15e7 100644 --- a/kernel/src/network/raw_socket.rs +++ b/kernel/src/network/raw_socket.rs @@ -19,7 +19,7 @@ pub enum NetworkErrors { PortInUse, NoAvailablePort, NonexistentHost, - SocketInServerMode, + BadSocketState, FeatureNotAvailableYet, Timeout, } diff --git a/kernel/src/network/rtl8139.rs b/kernel/src/network/rtl8139.rs index e380d76..0ad5a19 100644 --- a/kernel/src/network/rtl8139.rs +++ b/kernel/src/network/rtl8139.rs @@ -58,7 +58,7 @@ lazy_static! { // ! The safertl unwraps the locking mechanism with a RAII enable and disable (Well it would if I could get it to work) pub static ref NET_INFO: SafeRTL8139 = { let devices = devices::scan_devices(); - SafeRTL8139::new(spin::Mutex::new(RTL8139::new(devices))) + SafeRTL8139::new(spin::Mutex::new(RTL8139::new(devices)), spin::Mutex::new(NetworkConfig::new())) }; } @@ -195,11 +195,23 @@ pub struct RTL8139 { pub dhcp_server_ip: Option, pub mac_address: Option, pub open_ports: HashSet, - pub tcp_sessions: HashMap, pub ports: HashMap>, pub arp_table: Vec, } +pub struct NetworkConfig { + pub tcp_sessions: HashMap, +} + +impl NetworkConfig { + pub fn new() -> Self { + NetworkConfig { + tcp_sessions: HashMap::with_capacity(10), + } + } + +} + impl RTL8139 { // Initialize the card pub fn init( @@ -272,7 +284,6 @@ impl RTL8139 { open_ports: HashSet::with_capacity(10), ports: HashMap::with_capacity(10), arp_table: Vec::new(), - tcp_sessions: HashMap::with_capacity(10), }); } None diff --git a/kernel/src/network/socket.rs b/kernel/src/network/socket.rs index 7eb427f..1ba715f 100644 --- a/kernel/src/network/socket.rs +++ b/kernel/src/network/socket.rs @@ -4,7 +4,10 @@ use alloc::vec; use alloc::vec::Vec; use futures_util::StreamExt; -use crate::{network::layer::LayerType, println}; +use crate::{ + network::{constants::TCP_ACK, layer::LayerType, tcp::TCPPacket}, + println, +}; use super::{ arp::ArpPacket, @@ -14,7 +17,8 @@ use super::{ layer::{HasChecksum, Layer, PacketData}, raw_socket::{NetworkErrors, RawSocket}, rtl8139::{disable_network_interrupts, enable_network_interrupts, NET_INFO}, - udp::UDPPacket, tcp_session::TCPSession, + tcp_session::TCPSession, + udp::UDPPacket, }; pub struct NetworkQuery {} @@ -54,9 +58,9 @@ impl NetworkQuery { } else { disable_network_interrupts(); { - let rtl_dev_guard = NET_INFO.lock(); - let rtl_dev_info = rtl_dev_guard.get_ref().unwrap(); - rtl_dev_info.send_packet(&arp_layer.serialize()); + let rtl_dev_guard = NET_INFO.lock(); + let rtl_dev_info = rtl_dev_guard.get_ref().unwrap(); + rtl_dev_info.send_packet(&arp_layer.serialize()); } enable_network_interrupts(); } @@ -92,11 +96,16 @@ pub struct Socket { src_address: u32, src_mac: u64, wait_timeout: u16, + session_key: u64, } impl Socket { - // Can't send yet - pub async fn open(socket_type: SocketType, src_port: u16, wait_timeout: u16) -> Result { + // Can't send yet -- must listen + pub async fn open( + socket_type: SocketType, + src_port: u16, + wait_timeout: u16, + ) -> Result { let mut chosen_src_port = src_port; disable_network_interrupts(); let rtl_dev_info_locked = NET_INFO.lock(); @@ -131,6 +140,7 @@ impl Socket { src_address, src_mac, wait_timeout, + session_key: 0, }), Err(err) => Err(err), } @@ -163,8 +173,10 @@ impl Socket { let tcp_pkt = pkt.unwrap_tcp(); let dest_address = tcp_pkt.ip.source_ip.val(); let dest_port = tcp_pkt.src_port.val(); - let session_key = TCPSession::gen_session_key(dest_address, dest_port, self.src_port); - let raw_socket = RawSocket::new(session_key, max(self.wait_timeout * 18, 1)).unwrap(); + let session_key = + TCPSession::gen_session_key(dest_address, dest_port, self.src_port); + let raw_socket = + RawSocket::new(session_key, max(self.wait_timeout * 18, 1)).unwrap(); println!("[INFO] Spawned new TCP session"); return Some(Socket { socket_type: SocketType::TCP, @@ -177,8 +189,13 @@ impl Socket { src_address: self.src_address, src_mac: self.src_mac, wait_timeout: self.wait_timeout, + session_key, }); + } else { + // println!("[DEBUG] not useful packet"); } + } else { + // println!("[DEBUG] Got none"); } } } @@ -228,6 +245,7 @@ impl Socket { src_address, src_mac, wait_timeout, + session_key: 0, // todo: Connect should generate a tcp session... }), Err(err) => Err(err), } @@ -239,7 +257,7 @@ impl Socket { pub async fn read(&mut self, size: usize) -> Result, NetworkErrors> { if self.socket_state == SocketState::Listening { - return Err(NetworkErrors::SocketInServerMode); + return Err(NetworkErrors::BadSocketState); } match self.socket_type { SocketType::UDP => self.read_udp(size).await, @@ -280,22 +298,73 @@ impl Socket { } } - pub fn write(&self, data: &mut Vec) -> Result { + pub async fn write(&mut self, data: &mut Vec) -> Result { match self.socket_type { SocketType::UDP => self.write_udp(data), - SocketType::TCP => self.write_tcp(data), + SocketType::TCP => self.write_tcp(data).await, } } - fn write_tcp(&self, data: &mut Vec) -> Result { - let _ = data; - Err(NetworkErrors::FeatureNotAvailableYet) + async fn write_tcp(&mut self, data: &mut Vec) -> Result { + assert!(self.socket_type == SocketType::TCP); + if self.socket_state == SocketState::Listening { + return Err(NetworkErrors::BadSocketState); + } + + // Set up to receive ack + let mut tcp_session_guard = NET_INFO.config.lock(); + let tcp_session = tcp_session_guard + .tcp_sessions + .get_mut(&self.session_key) + .unwrap(); + let message_pkt = tcp_session.process_send(data); + if let Err(err) = message_pkt { + return Err(err); + } + let message = message_pkt.unwrap().serialize(); // todo: error check + drop(tcp_session_guard); + + // Wait for the ack + for retries in 1..21 { + // Send packet + disable_network_interrupts(); + let mut rtl_dev_info_locked = NET_INFO.lock(); + let rtl_dev_info = rtl_dev_info_locked.get_mut().unwrap(); + rtl_dev_info.send_packet(&message); + drop(rtl_dev_info_locked); + enable_network_interrupts(); + + // Get next packet or timeout + if let Some(pkt) = self.raw_socket.next().await { + println!("Got packet here!"); + if pkt.get_type() == LayerType::TCP { + // todo: better notification system for raw_socket (alternative to .next)... + // let pkt_data = pkt.unwrap_tcp(); + // Check the acknowledgement to make sure everything is acked + let mut tcp_session_guard = NET_INFO.config.lock(); + let tcp_session = tcp_session_guard + .tcp_sessions + .get_mut(&self.session_key) + .unwrap(); + if tcp_session.everything_acked() { + // if so break out of the timeout loop + break; + } + } + } + if retries == 20 { + // We reach too many timeouts so we return an error + return Err(NetworkErrors::Timeout); + } + } + // Return ok + Ok(data.len() as u16) } fn write_udp(&self, data: &mut Vec) -> Result { assert!(self.socket_type == SocketType::UDP); if self.socket_state == SocketState::Listening { - return Err(NetworkErrors::SocketInServerMode); + return Err(NetworkErrors::BadSocketState); } let eth_layer = EthernetPacket::gen(self.src_mac, self.dest_mac, ethernet::EthType::IPv4); let udp_size = UDPPacket::packet_size() + data.len() as u16; diff --git a/kernel/src/network/tcp.rs b/kernel/src/network/tcp.rs index 5ade514..4a672a2 100644 --- a/kernel/src/network/tcp.rs +++ b/kernel/src/network/tcp.rs @@ -1,8 +1,6 @@ use alloc::vec; use alloc::vec::Vec; -use crate::println; - use super::{ bytefield::{Bytefield16, Bytefield32}, ip::IPPacket, @@ -48,18 +46,18 @@ impl TCPPacket { dest_port: Bytefield16::new(dest_port), seq_num: Bytefield32::new(0), ack_num: Bytefield32::new(0), - flags: Bytefield16::new(6 << 12), // a value of 6 in the header_offset (6*4 = 24 bits -> b/c no options) + flags: Bytefield16::new(5 << 12), // a value of 5 in the header_offset (5*4 = 20 bits -> b/c no options) sliding_window: Bytefield16::new(u16::MAX), // sliding window :(, pain to implement... we can just allow unlimited data but thats insecure... Also we would like to keep track of how much we are allowed to send checksum: Bytefield16::new(0), urgent: Bytefield16::new(0), - options: vec![0x02, 0x04, 0x05, 0xb4], // need to get to 24 + options: vec![], // todo: we set the mss here manually -- fix data: vec![], } } pub fn options_size() -> u16 { - 4 + 0 } // N.B. Getting operations DONT swap endianness because we should be in host order (and after parsing we are) @@ -112,19 +110,19 @@ impl Layer for TCPPacket { packet.checksum = Bytefield16::read_inc(&bytevec[i..], &mut i); packet.urgent = Bytefield16::read_inc(&bytevec[i..], &mut i); // read remaining bytes and place them into the data buffer - println!("[INFO] Packet header offset {}", packet.get_header_offset()); - println!("[INFO] Buffer is {}", bytevec.len()); for _ in 0..(packet.get_header_offset() - 20) { packet.options.push(bytevec[i]); i += 1; } assert!(i == packet.get_header_offset().into()); // Valid i is header_length - let data_size = packet.ip.total_length.val() - packet.get_header_offset() as u16 - IPPacket::packet_size(); + let data_size = packet.ip.total_length.val() + - packet.get_header_offset() as u16 + - IPPacket::packet_size(); for _ in 0..data_size { - packet.options.push(bytevec[i]); + packet.data.push(bytevec[i]); i += 1; } - (packet, i, LayerType::UNDEF) + (packet, i, LayerType::END) } fn serialize(&self) -> alloc::vec::Vec { @@ -168,7 +166,7 @@ impl HasChecksum for TCPPacket { // Sum protocol and length let protocol = Bytefield16::new(ip.protocol as u16); sum += (protocol.data[0] as u32) | (protocol.data[1] as u32) << 8; - let segment_size = Bytefield16::new(TCPPacket::packet_size() + TCPPacket::options_size()); + let segment_size = Bytefield16::new(TCPPacket::packet_size() + TCPPacket::options_size() + self.data.len() as u16); sum += (segment_size.data[0] as u32) | (segment_size.data[1] as u32) << 8; // Zero the checksum field diff --git a/kernel/src/network/tcp_session.rs b/kernel/src/network/tcp_session.rs index a691bd6..8c30fbb 100644 --- a/kernel/src/network/tcp_session.rs +++ b/kernel/src/network/tcp_session.rs @@ -1,10 +1,14 @@ +use alloc::vec::Vec; + +use crate::println; + use super::{ bytefield::{Bytefield16, Bytefield32}, - constants::{TCP_ACK, TCP_FIN, TCP_RST, TCP_SYN}, + constants::{TCP_ACK, TCP_FIN, TCP_RST, TCP_SYN, TCP_PSH}, ethernet::EthernetPacket, ip::IPPacket, layer::{HasChecksum, Layer}, - tcp::TCPPacket, + tcp::TCPPacket, raw_socket::NetworkErrors, }; #[derive(Debug, PartialEq, Eq)] @@ -21,9 +25,16 @@ pub struct TCPSession { pub dest_ip: u32, pub dest_port: u16, pub src_port: u16, - pub send_acked_up_to: u32, - pub recv_acked_up_to: u32, + /// our seq num + pub sent_data_amount: u32, + /// how much the client has acked + pub sent_data_acked: u32, + /// our ack num + pub recv_data_amount: u32, + // pub recv_data_acked: u32, implicit how much we have acked -- this value is unknown to use pub window_size: u16, + i_am_closing: bool, + received_ack_to_fin: bool, session_state: TCPSessionState, } @@ -31,12 +42,15 @@ impl TCPSession { pub fn new(session_template: TCPPacket, dest_ip: u32, dest_port: u16, src_port: u16) -> Self { TCPSession { session_template, - send_acked_up_to: 0, - recv_acked_up_to: 0, + sent_data_amount: 55304, + sent_data_acked: 55304, + recv_data_amount: 0, dest_ip, dest_port, src_port, window_size: u16::MAX, + i_am_closing: false, + received_ack_to_fin: false, session_state: TCPSessionState::Waiting, } } @@ -55,7 +69,7 @@ impl TCPSession { } pub fn close(&mut self) { - self.session_state = TCPSessionState::Closing; + // self.session_state = TCPSessionState::Closing; // get rtl object // transition into } @@ -65,40 +79,78 @@ impl TCPSession { // send a rst and immediately stop receiving } - pub fn gen_acknowledgement(&mut self, request: &TCPPacket) -> (Option, bool) { + pub fn everything_acked(&self) -> bool { + self.sent_data_acked == self.sent_data_amount + } + + pub fn process_send(&mut self, data: &Vec) -> Result { + if self.session_state != TCPSessionState::Established { + return Err(NetworkErrors::BadSocketState); + } + let mut tcp_pkt = self.session_template.clone(); + tcp_pkt.turn_on_flags(TCP_ACK | TCP_PSH); + tcp_pkt.data = data.to_vec(); // todo: split the data and return the amount actually written! + + // add the data size (maybe make this automatic?) + tcp_pkt.ip.total_length = Bytefield16::new( + TCPPacket::packet_size() + TCPPacket::options_size() + IPPacket::packet_size() + data.len() as u16, + ); + tcp_pkt.seq_num = Bytefield32::new(self.sent_data_amount); + self.sent_data_amount += data.len() as u32; + tcp_pkt.ack_num = Bytefield32::new(self.recv_data_amount); + + // Calculate checksums + let data = tcp_pkt.serialize(); + let start_tcp = IPPacket::packet_size() as usize + EthernetPacket::packet_size() as usize; + let start_ip = EthernetPacket::packet_size() as usize; + tcp_pkt.ip.calculate_checksum(&data[start_ip..start_tcp]); + tcp_pkt.calculate_checksum(&data[start_tcp..]); + + Ok(tcp_pkt) + } + + pub fn process_recv(&mut self, request: &TCPPacket) -> (Option, bool) { let has_syn_flag = (request.get_flags() & TCP_SYN) != 0; let has_ack_flag = (request.get_flags() & TCP_ACK) != 0; let has_fin_flag = (request.get_flags() & TCP_FIN) != 0; let has_rst_flag = (request.get_flags() & TCP_RST) != 0; + // println!("[TCP {:?}] Req has Syn={} Ack={} Fin={} Rst={}", self.session_state, has_syn_flag, has_ack_flag, has_fin_flag, has_rst_flag); // regularly update window size self.window_size = request.sliding_window.val(); let mut response = self.session_template.clone(); // Setting length to be 24 [sizeof(TCP-header) + sizeof(options)] + 20 [sizeof(IP-header)] - response.ip.total_length = - Bytefield16::new(TCPPacket::packet_size() + TCPPacket::options_size() + IPPacket::packet_size()); - let has_data = !request.data.is_empty(); + response.ip.total_length = Bytefield16::new( + TCPPacket::packet_size() + TCPPacket::options_size() + IPPacket::packet_size(), + ); + if has_fin_flag { + // todo! check state here + self.session_state = TCPSessionState::Closing; + } if has_rst_flag { - if self.session_state != TCPSessionState::Waiting - && self.session_state != TCPSessionState::Syncing - { + println!("[TCP] Immediately dying"); + if self.session_state == TCPSessionState::Established { self.session_state = TCPSessionState::Closed; } - return (None, has_data); + return (None, false); } else { match self.session_state { TCPSessionState::Waiting => { if has_syn_flag && has_ack_flag { response.turn_on_flags(TCP_ACK); - if request.ack_num.val() != 1 { + if request.ack_num.val() != self.sent_data_amount + 1 { // Incorrect ack num return (None, false); } - response.ack_num = Bytefield32::new(request.seq_num.val() + 1); - response.seq_num = Bytefield32::new(45503); + self.sent_data_amount += 1; + self.sent_data_acked += 1; + self.recv_data_amount = request.seq_num.val() + 1; + response.ack_num = Bytefield32::new(self.recv_data_amount); + response.seq_num = Bytefield32::new(self.sent_data_amount); } else if has_syn_flag { response.turn_on_flags(TCP_SYN | TCP_ACK); - response.ack_num = Bytefield32::new(request.seq_num.val() + 1); - response.seq_num = Bytefield32::new(45503); + self.recv_data_amount = request.seq_num.val() + 1; + response.ack_num = Bytefield32::new(self.recv_data_amount); + response.seq_num = Bytefield32::new(self.sent_data_amount); } else { // No syn packet but the session hasn't been established // Therefore we drop the packet @@ -108,25 +160,65 @@ impl TCPSession { } TCPSessionState::Syncing => { if has_ack_flag { - if request.ack_num.val() != 1 { + if request.ack_num.val() != self.sent_data_amount + 1 { // Incorrect ack num return (None, false); } + self.sent_data_amount += 1; + self.sent_data_acked += 1; self.session_state = TCPSessionState::Established; + // Has no need for a response + return (None, true); } else { // Waiting on the ack packet // Dropping this packet return (None, false); } } - TCPSessionState::Established => {} + TCPSessionState::Established => { + let mut has_info = false; + if request.ack_num.val() > self.sent_data_acked && request.ack_num.val() <= self.sent_data_amount { + // We are updating our record of how much data was acked + self.sent_data_acked = request.ack_num.val(); + has_info = true; + } + if request.seq_num.val() == self.recv_data_amount && !request.data.is_empty() { + // Sequence number matches and we have data (So we need to ack) + self.recv_data_amount += request.data.len() as u32; + response.ack_num = Bytefield32::new(self.recv_data_amount); + response.seq_num = Bytefield32::new(self.sent_data_amount); + response.turn_on_flags(TCP_ACK); + } else { + return (None, has_info); + } + } TCPSessionState::Closing => { - if has_fin_flag { + if self.i_am_closing { + // if has_fin_flag && has_ack_flag { + // response.turn_on_flags(TCP_ACK); + // if self.received_ack_to_fin { + // self.session_state = TCPSessionState::Closed; + // } + // self.received_ack_to_fin = true; + // } + // if has_ack_flag { + // if self.received_ack_to_fin { + // self.session_state = TCPSessionState::Closed; + // } + // self.received_ack_to_fin = true; + // return (None, false); + // } + } else { + response.turn_on_flags(TCP_FIN | TCP_ACK); self.session_state = TCPSessionState::Closed; } - // send fin packet + self.recv_data_amount += 1; + response.ack_num = Bytefield32::new(self.recv_data_amount); + response.seq_num = Bytefield32::new(self.sent_data_amount); } TCPSessionState::Closed => { + // drop all packets -- session is closed + return (None, false); // shouldn't a closed tcp session be removed? // todo: define this behavior } @@ -138,6 +230,6 @@ impl TCPSession { let start_ip = EthernetPacket::packet_size() as usize; response.ip.calculate_checksum(&data[start_ip..start_tcp]); response.calculate_checksum(&data[start_tcp..]); - (Some(response), has_data) + (Some(response), true) } } diff --git a/kernel/src/network/udp.rs b/kernel/src/network/udp.rs index 23c12c0..e35000b 100644 --- a/kernel/src/network/udp.rs +++ b/kernel/src/network/udp.rs @@ -66,7 +66,7 @@ impl Layer for UDPPacket { i += 1; } assert!(i == packet.length.val() as usize); - LayerType::UNDEF + LayerType::END } }; (packet, i, layer_type) diff --git a/kernel/src/task/tcp_echo.rs b/kernel/src/task/tcp_echo.rs index 2fb34c2..fab5ef5 100644 --- a/kernel/src/task/tcp_echo.rs +++ b/kernel/src/task/tcp_echo.rs @@ -4,7 +4,7 @@ use crate::{ use alloc::string::String; pub async fn tcp_echo_server() { - let socket_or_err = Socket::open(SocketType::TCP, 6664, 5).await; + let socket_or_err = Socket::open(SocketType::TCP, 6664, 1).await; match socket_or_err { Ok(mut socket_gen) => { // Allow 10 sockets @@ -19,21 +19,21 @@ pub async fn tcp_echo_server() { print!("[USER] {}", message); if message.to_lowercase().trim() == "quit" || message.to_lowercase().trim() == "exit" { println!("Closing socket"); - // socket.write(&mut ("Closing socket...".as_bytes().to_vec())); + let mut exit_message = "Closing socket...\n".as_bytes().to_vec(); + (socket.write(&mut exit_message).await).unwrap(); socket.close(); break; } }, Err(err) => println!("[USER-ERR] {:?}", err), } - let res_or_err = socket.write(&mut data); + let res_or_err = socket.write(&mut data).await; if let Err(err) = res_or_err { println!("[ERR] {:?}", err); break; } } else if let Err(err) = data_or_err { if err == NetworkErrors::Timeout { - println!("[INFO] Socket had a timeout"); continue; } println!("[ERR] {:?}", err); diff --git a/kernel/src/task/udp_echo.rs b/kernel/src/task/udp_echo.rs index 459702a..3134df6 100644 --- a/kernel/src/task/udp_echo.rs +++ b/kernel/src/task/udp_echo.rs @@ -18,14 +18,14 @@ pub async fn udp_echo_server() { print!("[USER] {}", message); if message.to_lowercase().trim() == "quit" || message.to_lowercase().trim() == "exit" { println!("Closing socket"); - let _ = socket.write(&mut ("Closing socket...".as_bytes().to_vec())); + let _ = socket.write(&mut ("Closing socket...".as_bytes().to_vec())).await; socket.close(); return; } }, Err(err) => println!("[USER-ERR] {:?}", err), } - let res_or_err = socket.write(&mut data); + let res_or_err = socket.write(&mut data).await; if let Err(err) = res_or_err { println!("[ERR] {:?}", err); break; From afd569ecf758f5bd3363d695c5e4bd9d14916bf4 Mon Sep 17 00:00:00 2001 From: EthanLavi <89608076+EthanLavi@users.noreply.github.com> Date: Fri, 17 Nov 2023 20:23:43 -0500 Subject: [PATCH 14/36] TCP shutdown --- kernel/src/network/README.md | 4 +- kernel/src/network/TODO.md | 1 - kernel/src/network/init.rs | 6 +- kernel/src/network/processing.rs | 162 +++++++++++++++------------- kernel/src/network/raw_socket.rs | 15 ++- kernel/src/network/rtl8139.rs | 3 +- kernel/src/network/socket.rs | 90 +++++++++++++--- kernel/src/network/tcp_session.rs | 173 +++++++++++++++++++----------- kernel/src/task/tcp_echo.rs | 18 +++- 9 files changed, 304 insertions(+), 168 deletions(-) diff --git a/kernel/src/network/README.md b/kernel/src/network/README.md index ccbb4fb..cb7362a 100644 --- a/kernel/src/network/README.md +++ b/kernel/src/network/README.md @@ -6,7 +6,7 @@ TODO [x] PCI scanning for devices [x] RTL8139 Driver Code -[x] Ethernet, IP, UDP, ARP, DHCP +[x] Ethernet, IP, UDP, ARP, DHCP, TCP [x] RawSocket API [x] Better Socket API [x] Async IO @@ -17,8 +17,6 @@ TODO [] Verify other parts of the packet [] Fix synchronization to be much cleaner [] Clean up ugly stuff -[] DHCP parse additional options -[] TCP [] Refactor to be all constants [] search for todo and fix thoses diff --git a/kernel/src/network/TODO.md b/kernel/src/network/TODO.md index 5207120..8fd86bd 100644 --- a/kernel/src/network/TODO.md +++ b/kernel/src/network/TODO.md @@ -8,7 +8,6 @@ * Fix checksums to be baked-in * Clean up ugly stuff * DHCP parse additional options -* TCP * Refactor to be all constants * Search for todo and fix thoses * Benchmarking diff --git a/kernel/src/network/init.rs b/kernel/src/network/init.rs index f2b7a6f..9f15275 100644 --- a/kernel/src/network/init.rs +++ b/kernel/src/network/init.rs @@ -51,7 +51,11 @@ pub async fn init_dhcp(wait_timeout: u8) -> bool { let pkt_data; loop { if let Some(dhcp_res) = socket.next().await { - pkt_data = dhcp_res; + if dhcp_res.is_err() { + // If we got a socket error, we must return false + return false; + } + pkt_data = dhcp_res.unwrap(); break; } else { disable_network_interrupts(); diff --git a/kernel/src/network/processing.rs b/kernel/src/network/processing.rs index 9883e64..d1c7895 100644 --- a/kernel/src/network/processing.rs +++ b/kernel/src/network/processing.rs @@ -7,14 +7,14 @@ use crate::{ network::{ arp::ArpPacket, arp_table::ArpEntry, - constants::{ARP_PORT, BROADCAST_ADDR, TCP_SYN}, + constants::{ARP_PORT, BROADCAST_ADDR, TCP_SYN, TCP_FIN, TCP_ACK}, ethernet::{EthType, EthernetPacket}, ip::{IPPacket, Protocol}, layer::{Layer, LayerType, PacketData}, - raw_socket::wake_sockets, + raw_socket::{wake_sockets, NetworkErrors}, rtl8139::{disable_network_interrupts, enable_network_interrupts, NET_INFO}, tcp::TCPPacket, - tcp_session::TCPSession, + tcp_session::{TCPSession, SessionAction}, }, println, }; @@ -118,7 +118,7 @@ pub async fn init_packet_processing() { .ports .get_mut(&ARP_PORT) .unwrap() - .push_back(PacketData::ARP(arp)); + .push_back(Ok(PacketData::ARP(arp))); wake_sockets(ARP_PORT); } } @@ -136,7 +136,7 @@ pub async fn init_packet_processing() { .ports .get_mut(&dst_port) .unwrap() - .push_back(PacketData::DHCP(dhcp)); + .push_back(Ok(PacketData::DHCP(dhcp))); wake_sockets(dst_port); } } @@ -151,92 +151,104 @@ pub async fn init_packet_processing() { .ports .get_mut(&dst_port) .unwrap() - .push_back(PacketData::UDP(udp)); + .push_back(Ok(PacketData::UDP(udp))); wake_sockets(dst_port); } } PacketData::TCP(tcp) => { let dst_port = tcp.dest_port.val() as u64; - if rtl_dev_info.open_ports.contains(&dst_port) { - // if we are listening on the port, try to insert it into the map - if !rtl_dev_info.ports.contains_key(&dst_port) { - rtl_dev_info.ports.insert(dst_port, VecDeque::new()); + if !rtl_dev_info.open_ports.contains(&dst_port) { + continue; + } + // if we are listening on the port, try to insert it into the map + if !rtl_dev_info.ports.contains_key(&dst_port) { + rtl_dev_info.ports.insert(dst_port, VecDeque::new()); + } + let session_key = TCPSession::gen_session_key( + tcp.ip.source_ip.val(), + tcp.src_port.val(), + tcp.dest_port.val(), + ); + // Open up a session + if !rtl_dev_config.tcp_sessions.contains_key(&session_key) { + if (tcp.get_flags() & TCP_SYN) == 0 { + // Ignore requests when there is no request for syncing + enable_network_interrupts(); + return; } - let session_key = TCPSession::gen_session_key( + // Compact the first packet we receive as the session creation + let eth_layer = EthernetPacket::gen( + tcp.ip.eth.src_mac.val(), + tcp.ip.eth.dest_mac.val(), + EthType::IPv4, + ); + let ip_layer = IPPacket::gen( + eth_layer, + 0, // leaving size undefined for the template + Protocol::TCP, + tcp.ip.destination_ip.val(), + tcp.ip.source_ip.val(), + ); + let tcp_layer = + TCPPacket::gen(ip_layer, tcp.dest_port.val(), tcp.src_port.val()); + let session = TCPSession::new( + tcp_layer, tcp.ip.source_ip.val(), tcp.src_port.val(), tcp.dest_port.val(), ); - // Open up a session - if !rtl_dev_config.tcp_sessions.contains_key(&session_key) { - if (tcp.get_flags() & TCP_SYN) == 0 { - // Ignore requests when there is no request for syncing - enable_network_interrupts(); - return; + // Lets push to the port -- we are listening then we need to create a new session + rtl_dev_info + .ports + .get_mut(&dst_port) + .unwrap() + .push_back(Ok(PacketData::TCP(tcp.clone()))); + wake_sockets(dst_port); + rtl_dev_config + .tcp_sessions + .insert(session.session_key(), session); + } + // todo: SYN COOKIES? + // todo: Should I have a buffer limit? -- + // ? I think maybe no, because upstream data should be prioritized -- + // ? Otherwise I might not have enough processing speed + let tcp_session = rtl_dev_config.tcp_sessions.get_mut(&session_key).unwrap(); + println!("Recv packet! Fin:{} Ack:{} State:{:?}", tcp.get_flags() & TCP_FIN != 0, tcp.get_flags() & TCP_ACK != 0, tcp_session.session_state); + + // Generate an acknowledgement and determine if the tcp packet has data + let ack_pkt = tcp_session.process_recv(&tcp); + if let Some(response) = ack_pkt.0 { + // If we got a response packet to send back + // todo: what happens if our response is dropped... we need to re-ack? + // generally if no ack is receeived, the host will send another transmission + // ALSO we don't know if our ack was received or not, so we just wait for another transmission + // we could also get duplicate data tho? so we need to identify this case + rtl_dev_info.send_packet(&response.serialize()); + } + if ack_pkt.1 != SessionAction::Drop { + // Push the packet to the raw socket --> it will handle its data, if present + if rtl_dev_info.open_ports.contains(&session_key) { + // if we are listening on the session, try to insert it into the map + if !rtl_dev_info.ports.contains_key(&session_key) { + rtl_dev_info.ports.insert(session_key, VecDeque::new()); } - // Compact the first packet we receive as the session creation - let eth_layer = EthernetPacket::gen( - tcp.ip.eth.src_mac.val(), - tcp.ip.eth.dest_mac.val(), - EthType::IPv4, - ); - let ip_layer = IPPacket::gen( - eth_layer, - 0, // leaving size undefined for the template - Protocol::TCP, - tcp.ip.destination_ip.val(), - tcp.ip.source_ip.val(), - ); - let tcp_layer = - TCPPacket::gen(ip_layer, tcp.dest_port.val(), tcp.src_port.val()); - let session = TCPSession::new( - tcp_layer, - tcp.ip.source_ip.val(), - tcp.src_port.val(), - tcp.dest_port.val(), - ); - // Lets push to the port -- we are listening then we need to create a new session + // Insert the data and wake the socket + let res = if ack_pkt.1 == SessionAction::PushUpstream { + Ok(PacketData::TCP(tcp)) + } else if ack_pkt.1 == SessionAction::EndOfStream { + Err(NetworkErrors::ClosedSocket) + } else { + unreachable!(); + }; rtl_dev_info .ports - .get_mut(&dst_port) + .get_mut(&session_key) .unwrap() - .push_back(PacketData::TCP(tcp.clone())); - wake_sockets(dst_port); - rtl_dev_config - .tcp_sessions - .insert(session.session_key(), session); - } - // todo: SYN COOKIES - // todo: Should I have a buffer limit? -- - // ? I think maybe no, because upstream data should be prioritized -- - let tcp_session = rtl_dev_config.tcp_sessions.get_mut(&session_key).unwrap(); - // Generate an acknowledgement and determine if the tcp packet has data - let ack_pkt = tcp_session.process_recv(&tcp); - if let Some(response) = ack_pkt.0 { - // If we got a response packet to send back - // todo: what happens if our response is dropped... we need to re-ack? - // generally if no ack is receeived, the host will send another transmission - // ALSO we don't know if our ack was received or not, so we just wait for another transmission - // we could also get duplicate data tho? so we need to identify this case - rtl_dev_info.send_packet(&response.serialize()); - } - if ack_pkt.1 { - // Push the packet to the raw socket --> it will handle its data, if present - if rtl_dev_info.open_ports.contains(&session_key) { - // if we are listening on the session, try to insert it into the map - if !rtl_dev_info.ports.contains_key(&session_key) { - rtl_dev_info.ports.insert(session_key, VecDeque::new()); - } - // Insert the data and wake the socket - rtl_dev_info - .ports - .get_mut(&session_key) - .unwrap() - .push_back(PacketData::TCP(tcp)); - wake_sockets(session_key); - } + .push_back(res); + wake_sockets(session_key); } } + // todo: Push end of stream when necessary } _ => {} // ignore others } diff --git a/kernel/src/network/raw_socket.rs b/kernel/src/network/raw_socket.rs index 54e15e7..842f07d 100644 --- a/kernel/src/network/raw_socket.rs +++ b/kernel/src/network/raw_socket.rs @@ -22,6 +22,8 @@ pub enum NetworkErrors { BadSocketState, FeatureNotAvailableYet, Timeout, + /// This is a special network error for when our TCP stream has closed + ClosedSocket, } pub struct RawSocket { @@ -49,7 +51,7 @@ impl RawSocket { Ok(RawSocket { port, timeout_in_epochs, timeout_active: false, timeout_id: TimeoutID::new(), }) } - fn try_get_packet_inner(&self) -> Option { + fn try_get_packet_inner(&self) -> Option> { let mut rtl_dev_info_guard = NET_INFO.lock(); let rtl_dev_info = rtl_dev_info_guard.get_mut().unwrap(); match rtl_dev_info.ports.get_mut(&self.port) { @@ -71,18 +73,15 @@ impl RawSocket { let vec = rtl_dev_info.ports.get_mut(&self.port); vec.unwrap().clear(); } - // remove tcp session information - // todo: what about udp? what about sending a FIN - /*let session_key = TCPSession::gen_session_key(self.dest_ip, self.dest_port); - if rtl_dev_info.tcp_sessions.contains_key(&session_key) { - rtl_dev_info.tcp_sessions.remove(&session_key); - }*/ + // tcp session information is removed by the socket (not the raw socket) + // since the socket is typed (tcp or udp) and the raw_socket purpose is just + // for polling for packets enable_network_interrupts(); } } impl Stream for RawSocket { - type Item = PacketData; + type Item = Result; fn poll_next(mut self: core::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { disable_network_interrupts(); diff --git a/kernel/src/network/rtl8139.rs b/kernel/src/network/rtl8139.rs index 0ad5a19..2ce871c 100644 --- a/kernel/src/network/rtl8139.rs +++ b/kernel/src/network/rtl8139.rs @@ -25,6 +25,7 @@ use super::constants::{ INTERRUPT_MASK, ROK, RTL_DEV, RTL_VEND, TOK, TRANSMIT_CMD, TRANSMIT_REG, }; use super::processing::add_pkt_data; +use super::raw_socket::NetworkErrors; use super::tcp_session::TCPSession; use super::{ arp_table::ArpEntry, @@ -195,7 +196,7 @@ pub struct RTL8139 { pub dhcp_server_ip: Option, pub mac_address: Option, pub open_ports: HashSet, - pub ports: HashMap>, + pub ports: HashMap>>, pub arp_table: Vec, } diff --git a/kernel/src/network/socket.rs b/kernel/src/network/socket.rs index 1ba715f..a1c681a 100644 --- a/kernel/src/network/socket.rs +++ b/kernel/src/network/socket.rs @@ -6,7 +6,7 @@ use futures_util::StreamExt; use crate::{ network::{constants::TCP_ACK, layer::LayerType, tcp::TCPPacket}, - println, + println, print, }; use super::{ @@ -50,10 +50,14 @@ impl NetworkQuery { let mut timeout = 0; loop { if let Some(pkt) = socket.next().await { - if pkt.get_type() != LayerType::ARP { + if pkt.is_err() { continue; } - let arp_pkt = pkt.unwrap_arp(); + let pkt_data = pkt.unwrap(); + if pkt_data.get_type() != LayerType::ARP { + continue; + } + let arp_pkt = pkt_data.unwrap_arp(); return Some(arp_pkt.sender_mac.val()); } else { disable_network_interrupts(); @@ -83,6 +87,7 @@ pub enum SocketType { enum SocketState { Listening, Ready, + Closed, } pub struct Socket { @@ -153,7 +158,12 @@ impl Socket { return None; // todo: this is failing silently } loop { - if let Some(pkt) = self.raw_socket.next().await { + if let Some(pkt_or_err) = self.raw_socket.next().await { + if pkt_or_err.is_err() { + // todo: this is failing silently + return None; + } + let pkt = pkt_or_err.unwrap(); if pkt.get_type() == LayerType::UDP && self.socket_type == SocketType::UDP { let udp_pkt = pkt.unwrap_udp(); self.dest_port = udp_pkt.src_port.val(); @@ -165,7 +175,7 @@ impl Socket { let rtl_dev_info = rtl_dev_info_guard.get_mut().unwrap(); // re-enqueue the packet if let Some(vec) = rtl_dev_info.ports.get_mut(&(self.src_port as u64)) { - vec.push_front(PacketData::UDP(udp_pkt)); + vec.push_front(Ok(PacketData::UDP(udp_pkt))); } enable_network_interrupts(); return None; @@ -251,12 +261,55 @@ impl Socket { } } - pub fn close(self) { + pub async fn close(mut self) { + if self.socket_state == SocketState::Closed { + // Already closed + return; + } + if self.socket_type == SocketType::TCP { + let mut rtl_dev_config = NET_INFO.config.lock(); + let session_key = TCPSession::gen_session_key(self.dest_ip, self.dest_port, self.src_port); + if rtl_dev_config.tcp_sessions.contains_key(&session_key) { + let session = rtl_dev_config.tcp_sessions.get_mut(&session_key).unwrap(); + if let Ok(pkt) = session.close() { + drop(rtl_dev_config); + // With 5 retries + 'outer: for _ in 0..5 { + // Send the FIN-ACK packet + disable_network_interrupts(); + let mut rtl_dev_info_locked = NET_INFO.lock(); + let rtl_dev_info = rtl_dev_info_locked.get_mut().unwrap(); + rtl_dev_info.send_packet(&pkt.serialize()); + drop(rtl_dev_info_locked); + enable_network_interrupts(); + + loop { + // Keep reading the stream + if let Err(next) = self.read_tcp(0).await { + if next == NetworkErrors::ClosedSocket { + // If we have a closed stream, we break out completely + break 'outer; + } + // Likely got a timeout - so retry + break; + } + } + + if NET_INFO.config.lock().tcp_sessions.get_mut(&session_key).unwrap().is_closed() { + // If the session is closed, break the outer loop + break 'outer; + } + } + } + // Remove the session + NET_INFO.config.lock().tcp_sessions.remove(&session_key); + } + } self.raw_socket.close(); } pub async fn read(&mut self, size: usize) -> Result, NetworkErrors> { - if self.socket_state == SocketState::Listening { + if self.socket_state != SocketState::Ready { return Err(NetworkErrors::BadSocketState); } match self.socket_type { @@ -267,7 +320,11 @@ impl Socket { async fn read_udp(&mut self, size: usize) -> Result, NetworkErrors> { loop { - if let Some(pkt) = self.raw_socket.next().await { + if let Some(pkt_or_err) = self.raw_socket.next().await { + if let Err(err) = pkt_or_err { + return Err(err); + } + let pkt = pkt_or_err.unwrap(); if pkt.get_type() == LayerType::UDP { let udp_pkt = pkt.unwrap_udp(); return Ok(udp_pkt.data); @@ -281,7 +338,11 @@ impl Socket { async fn read_tcp(&mut self, size: usize) -> Result, NetworkErrors> { loop { let mut res_vec = vec![]; - if let Some(pkt) = self.raw_socket.next().await { + if let Some(pkt_or_err) = self.raw_socket.next().await { + if let Err(err) = pkt_or_err { + return Err(err); + } + let pkt = pkt_or_err.unwrap(); if pkt.get_type() == LayerType::TCP { let mut tcp_pkt = pkt.unwrap_tcp(); res_vec.append(&mut tcp_pkt.data); @@ -307,7 +368,7 @@ impl Socket { async fn write_tcp(&mut self, data: &mut Vec) -> Result { assert!(self.socket_type == SocketType::TCP); - if self.socket_state == SocketState::Listening { + if self.socket_state != SocketState::Ready { return Err(NetworkErrors::BadSocketState); } @@ -335,8 +396,11 @@ impl Socket { enable_network_interrupts(); // Get next packet or timeout - if let Some(pkt) = self.raw_socket.next().await { - println!("Got packet here!"); + if let Some(pkt_or_err) = self.raw_socket.next().await { + if let Err(err) = pkt_or_err { + return Err(err); + } + let pkt = pkt_or_err.unwrap(); if pkt.get_type() == LayerType::TCP { // todo: better notification system for raw_socket (alternative to .next)... // let pkt_data = pkt.unwrap_tcp(); @@ -363,7 +427,7 @@ impl Socket { fn write_udp(&self, data: &mut Vec) -> Result { assert!(self.socket_type == SocketType::UDP); - if self.socket_state == SocketState::Listening { + if self.socket_state != SocketState::Ready { return Err(NetworkErrors::BadSocketState); } let eth_layer = EthernetPacket::gen(self.src_mac, self.dest_mac, ethernet::EthType::IPv4); diff --git a/kernel/src/network/tcp_session.rs b/kernel/src/network/tcp_session.rs index 8c30fbb..2b44dd9 100644 --- a/kernel/src/network/tcp_session.rs +++ b/kernel/src/network/tcp_session.rs @@ -4,15 +4,16 @@ use crate::println; use super::{ bytefield::{Bytefield16, Bytefield32}, - constants::{TCP_ACK, TCP_FIN, TCP_RST, TCP_SYN, TCP_PSH}, + constants::{TCP_ACK, TCP_FIN, TCP_PSH, TCP_RST, TCP_SYN}, ethernet::EthernetPacket, ip::IPPacket, layer::{HasChecksum, Layer}, - tcp::TCPPacket, raw_socket::NetworkErrors, + raw_socket::NetworkErrors, + tcp::TCPPacket, }; #[derive(Debug, PartialEq, Eq)] -enum TCPSessionState { +pub enum TCPSessionState { Waiting, Syncing, Established, @@ -20,22 +21,31 @@ enum TCPSessionState { Closed, } +#[derive(Debug, PartialEq, Eq)] +pub enum SessionAction { + PushUpstream, + Drop, + EndOfStream, +} + pub struct TCPSession { session_template: TCPPacket, pub dest_ip: u32, pub dest_port: u16, pub src_port: u16, /// our seq num - pub sent_data_amount: u32, + pub sent_data_amount: u32, /// how much the client has acked - pub sent_data_acked: u32, + pub sent_data_acked: u32, /// our ack num - pub recv_data_amount: u32, + pub recv_data_amount: u32, // pub recv_data_acked: u32, implicit how much we have acked -- this value is unknown to use pub window_size: u16, - i_am_closing: bool, - received_ack_to_fin: bool, - session_state: TCPSessionState, + /// If the user has sent fin_ack closing + has_sent_fin_ack: bool, + has_recv_ack_to_fin_ack: bool, + has_sent_ack_to_fin_ack: bool, + pub session_state: TCPSessionState, } impl TCPSession { @@ -49,8 +59,9 @@ impl TCPSession { dest_port, src_port, window_size: u16::MAX, - i_am_closing: false, - received_ack_to_fin: false, + has_sent_fin_ack: false, + has_recv_ack_to_fin_ack: false, + has_sent_ack_to_fin_ack: false, session_state: TCPSessionState::Waiting, } } @@ -68,15 +79,59 @@ impl TCPSession { (dest_ip as u64) << 32 | (dest_port as u64) << 16 | src_port as u64 } - pub fn close(&mut self) { - // self.session_state = TCPSessionState::Closing; - // get rtl object - // transition into + pub fn is_closed(&self) -> bool { + self.session_state ==TCPSessionState::Closed || self.has_recv_ack_to_fin_ack && self.has_sent_ack_to_fin_ack } - pub fn reset(&mut self) { - // get rtl object - // send a rst and immediately stop receiving + /// Create a packet to close the tcp session with + /// Will also transition to the closing state + pub fn close(&mut self) -> Result { + if self.session_state != TCPSessionState::Established { + return Err(NetworkErrors::BadSocketState); + } + // Transition to the closing state + self.session_state = TCPSessionState::Closing; + + self.has_sent_fin_ack = true; + let mut tcp_pkt = self.session_template.clone(); + tcp_pkt.turn_on_flags(TCP_FIN | TCP_ACK); + + // add the data size (maybe make this automatic?) + tcp_pkt.ip.total_length = Bytefield16::new( + TCPPacket::packet_size() + TCPPacket::options_size() + IPPacket::packet_size(), + ); + tcp_pkt.seq_num = Bytefield32::new(self.sent_data_amount); + tcp_pkt.ack_num = Bytefield32::new(self.recv_data_amount); + + // Calculate checksums + let data = tcp_pkt.serialize(); + let start_tcp = IPPacket::packet_size() as usize + EthernetPacket::packet_size() as usize; + let start_ip = EthernetPacket::packet_size() as usize; + tcp_pkt.ip.calculate_checksum(&data[start_ip..start_tcp]); + tcp_pkt.calculate_checksum(&data[start_tcp..]); + + Ok(tcp_pkt) + } + + pub fn reset(&mut self) -> TCPPacket { + self.session_state = TCPSessionState::Closed; + let mut tcp_pkt = self.session_template.clone(); + tcp_pkt.turn_on_flags(TCP_RST); + + // add the data size (maybe make this automatic?) + tcp_pkt.ip.total_length = Bytefield16::new( + TCPPacket::packet_size() + TCPPacket::options_size() + IPPacket::packet_size(), + ); + tcp_pkt.seq_num = Bytefield32::new(self.sent_data_amount); + tcp_pkt.ack_num = Bytefield32::new(self.recv_data_amount); + + // Calculate checksums + let data = tcp_pkt.serialize(); + let start_tcp = IPPacket::packet_size() as usize + EthernetPacket::packet_size() as usize; + let start_ip = EthernetPacket::packet_size() as usize; + tcp_pkt.ip.calculate_checksum(&data[start_ip..start_tcp]); + tcp_pkt.calculate_checksum(&data[start_tcp..]); + tcp_pkt } pub fn everything_acked(&self) -> bool { @@ -90,10 +145,13 @@ impl TCPSession { let mut tcp_pkt = self.session_template.clone(); tcp_pkt.turn_on_flags(TCP_ACK | TCP_PSH); tcp_pkt.data = data.to_vec(); // todo: split the data and return the amount actually written! - + // add the data size (maybe make this automatic?) tcp_pkt.ip.total_length = Bytefield16::new( - TCPPacket::packet_size() + TCPPacket::options_size() + IPPacket::packet_size() + data.len() as u16, + TCPPacket::packet_size() + + TCPPacket::options_size() + + IPPacket::packet_size() + + data.len() as u16, ); tcp_pkt.seq_num = Bytefield32::new(self.sent_data_amount); self.sent_data_amount += data.len() as u32; @@ -109,29 +167,26 @@ impl TCPSession { Ok(tcp_pkt) } - pub fn process_recv(&mut self, request: &TCPPacket) -> (Option, bool) { + pub fn process_recv(&mut self, request: &TCPPacket) -> (Option, SessionAction) { let has_syn_flag = (request.get_flags() & TCP_SYN) != 0; let has_ack_flag = (request.get_flags() & TCP_ACK) != 0; let has_fin_flag = (request.get_flags() & TCP_FIN) != 0; let has_rst_flag = (request.get_flags() & TCP_RST) != 0; - // println!("[TCP {:?}] Req has Syn={} Ack={} Fin={} Rst={}", self.session_state, has_syn_flag, has_ack_flag, has_fin_flag, has_rst_flag); - // regularly update window size + // todo: regularly update window size self.window_size = request.sliding_window.val(); let mut response = self.session_template.clone(); // Setting length to be 24 [sizeof(TCP-header) + sizeof(options)] + 20 [sizeof(IP-header)] response.ip.total_length = Bytefield16::new( TCPPacket::packet_size() + TCPPacket::options_size() + IPPacket::packet_size(), ); - if has_fin_flag { - // todo! check state here + // todo! check seq-num and ack-num before doing these state changes + if has_fin_flag && self.session_state == TCPSessionState::Established { self.session_state = TCPSessionState::Closing; } if has_rst_flag { - println!("[TCP] Immediately dying"); - if self.session_state == TCPSessionState::Established { - self.session_state = TCPSessionState::Closed; - } - return (None, false); + println!("[TCP] Received RST -- dying {:?}", self.session_state); + self.session_state = TCPSessionState::Closed; + return (None, SessionAction::Drop); } else { match self.session_state { TCPSessionState::Waiting => { @@ -139,7 +194,7 @@ impl TCPSession { response.turn_on_flags(TCP_ACK); if request.ack_num.val() != self.sent_data_amount + 1 { // Incorrect ack num - return (None, false); + return (None, SessionAction::Drop); } self.sent_data_amount += 1; self.sent_data_acked += 1; @@ -154,7 +209,7 @@ impl TCPSession { } else { // No syn packet but the session hasn't been established // Therefore we drop the packet - return (None, false); + return (None, SessionAction::Drop); } self.session_state = TCPSessionState::Syncing; } @@ -162,25 +217,26 @@ impl TCPSession { if has_ack_flag { if request.ack_num.val() != self.sent_data_amount + 1 { // Incorrect ack num - return (None, false); + return (None, SessionAction::Drop); } self.sent_data_amount += 1; self.sent_data_acked += 1; self.session_state = TCPSessionState::Established; // Has no need for a response - return (None, true); + return (None, SessionAction::PushUpstream); } else { // Waiting on the ack packet // Dropping this packet - return (None, false); + return (None, SessionAction::Drop); } } TCPSessionState::Established => { - let mut has_info = false; - if request.ack_num.val() > self.sent_data_acked && request.ack_num.val() <= self.sent_data_amount { + let mut has_info = SessionAction::Drop; + if request.ack_num.val() > self.sent_data_acked + && request.ack_num.val() <= self.sent_data_amount { // We are updating our record of how much data was acked self.sent_data_acked = request.ack_num.val(); - has_info = true; + has_info = SessionAction::PushUpstream; // todo: why do we need to push up an empty packet? } if request.seq_num.val() == self.recv_data_amount && !request.data.is_empty() { // Sequence number matches and we have data (So we need to ack) @@ -193,34 +249,29 @@ impl TCPSession { } } TCPSessionState::Closing => { - if self.i_am_closing { - // if has_fin_flag && has_ack_flag { - // response.turn_on_flags(TCP_ACK); - // if self.received_ack_to_fin { - // self.session_state = TCPSessionState::Closed; - // } - // self.received_ack_to_fin = true; - // } - // if has_ack_flag { - // if self.received_ack_to_fin { - // self.session_state = TCPSessionState::Closed; - // } - // self.received_ack_to_fin = true; - // return (None, false); - // } - } else { - response.turn_on_flags(TCP_FIN | TCP_ACK); - self.session_state = TCPSessionState::Closed; + if has_fin_flag && has_ack_flag { + self.has_sent_ack_to_fin_ack = true; + if self.has_recv_ack_to_fin_ack { + self.session_state = TCPSessionState::Closed; + } + response.turn_on_flags(TCP_ACK); + } else if has_ack_flag && self.has_sent_fin_ack { + self.sent_data_amount += 1; + self.has_recv_ack_to_fin_ack = true; + if self.has_sent_ack_to_fin_ack { + self.session_state = TCPSessionState::Closed; + } + return (None, SessionAction::PushUpstream); } + self.recv_data_amount += 1; response.ack_num = Bytefield32::new(self.recv_data_amount); response.seq_num = Bytefield32::new(self.sent_data_amount); } TCPSessionState::Closed => { - // drop all packets -- session is closed - return (None, false); - // shouldn't a closed tcp session be removed? - // todo: define this behavior + // Send end of stream message + // todo: maybe make this a RST packet? + return (None, SessionAction::EndOfStream); } }; } @@ -230,6 +281,6 @@ impl TCPSession { let start_ip = EthernetPacket::packet_size() as usize; response.ip.calculate_checksum(&data[start_ip..start_tcp]); response.calculate_checksum(&data[start_tcp..]); - (Some(response), true) + (Some(response), SessionAction::PushUpstream) } } diff --git a/kernel/src/task/tcp_echo.rs b/kernel/src/task/tcp_echo.rs index fab5ef5..62e6607 100644 --- a/kernel/src/task/tcp_echo.rs +++ b/kernel/src/task/tcp_echo.rs @@ -13,6 +13,10 @@ pub async fn tcp_echo_server() { loop { let data_or_err = socket.read(0).await; if let Ok(mut data) = data_or_err { + if data.is_empty() { + // continue if we didn't read any data + continue; + } let user_message = String::from_utf8(data.clone()); match user_message { Ok(message) => { @@ -20,8 +24,10 @@ pub async fn tcp_echo_server() { if message.to_lowercase().trim() == "quit" || message.to_lowercase().trim() == "exit" { println!("Closing socket"); let mut exit_message = "Closing socket...\n".as_bytes().to_vec(); - (socket.write(&mut exit_message).await).unwrap(); - socket.close(); + let _ = socket.write(&mut exit_message).await; + println!("Wrote final message to socket"); + socket.close().await; + println!("Closed socket"); break; } }, @@ -29,19 +35,21 @@ pub async fn tcp_echo_server() { } let res_or_err = socket.write(&mut data).await; if let Err(err) = res_or_err { - println!("[ERR] {:?}", err); + println!("[ERR] (Writing): {:?}", err); + socket.close().await; break; } } else if let Err(err) = data_or_err { if err == NetworkErrors::Timeout { continue; } - println!("[ERR] {:?}", err); + println!("[ERR] (Reading): {:?}", err); + socket.close().await; break; } } } }, - Err(err) => println!("[ERR] {:?}", err), + Err(err) => println!("[ERR] (Listening): {:?}", err), } } \ No newline at end of file From 62d74c226a380a8e6dca2567e12ef579070b4eed Mon Sep 17 00:00:00 2001 From: EthanLavi <89608076+EthanLavi@users.noreply.github.com> Date: Fri, 17 Nov 2023 20:59:14 -0500 Subject: [PATCH 15/36] TCP done --- kernel/src/network/processing.rs | 1 - kernel/src/network/socket.rs | 19 ++++++++++--------- kernel/src/network/tcp_session.rs | 8 +++----- 3 files changed, 13 insertions(+), 15 deletions(-) diff --git a/kernel/src/network/processing.rs b/kernel/src/network/processing.rs index d1c7895..f297c87 100644 --- a/kernel/src/network/processing.rs +++ b/kernel/src/network/processing.rs @@ -248,7 +248,6 @@ pub async fn init_packet_processing() { wake_sockets(session_key); } } - // todo: Push end of stream when necessary } _ => {} // ignore others } diff --git a/kernel/src/network/socket.rs b/kernel/src/network/socket.rs index a1c681a..4c9c93d 100644 --- a/kernel/src/network/socket.rs +++ b/kernel/src/network/socket.rs @@ -271,10 +271,16 @@ impl Socket { let session_key = TCPSession::gen_session_key(self.dest_ip, self.dest_port, self.src_port); if rtl_dev_config.tcp_sessions.contains_key(&session_key) { let session = rtl_dev_config.tcp_sessions.get_mut(&session_key).unwrap(); - if let Ok(pkt) = session.close() { - drop(rtl_dev_config); - // With 5 retries - 'outer: for _ in 0..5 { + let fin_ack_pkt = session.close(); + drop(rtl_dev_config); + if let Ok(pkt) = fin_ack_pkt { + // With 6 retries + 'outer: for _ in 0..6 { + if NET_INFO.config.lock().tcp_sessions.get_mut(&session_key).unwrap().is_closed() { + // If the session is closed, break the outer loop + break 'outer; + } + // Send the FIN-ACK packet disable_network_interrupts(); let mut rtl_dev_info_locked = NET_INFO.lock(); @@ -294,11 +300,6 @@ impl Socket { break; } } - - if NET_INFO.config.lock().tcp_sessions.get_mut(&session_key).unwrap().is_closed() { - // If the session is closed, break the outer loop - break 'outer; - } } } // Remove the session diff --git a/kernel/src/network/tcp_session.rs b/kernel/src/network/tcp_session.rs index 2b44dd9..04b11a0 100644 --- a/kernel/src/network/tcp_session.rs +++ b/kernel/src/network/tcp_session.rs @@ -86,9 +86,6 @@ impl TCPSession { /// Create a packet to close the tcp session with /// Will also transition to the closing state pub fn close(&mut self) -> Result { - if self.session_state != TCPSessionState::Established { - return Err(NetworkErrors::BadSocketState); - } // Transition to the closing state self.session_state = TCPSessionState::Closing; @@ -172,6 +169,7 @@ impl TCPSession { let has_ack_flag = (request.get_flags() & TCP_ACK) != 0; let has_fin_flag = (request.get_flags() & TCP_FIN) != 0; let has_rst_flag = (request.get_flags() & TCP_RST) != 0; + let mut response_action = SessionAction::PushUpstream; // todo: regularly update window size self.window_size = request.sliding_window.val(); let mut response = self.session_template.clone(); @@ -255,6 +253,7 @@ impl TCPSession { self.session_state = TCPSessionState::Closed; } response.turn_on_flags(TCP_ACK); + response_action = SessionAction::EndOfStream; } else if has_ack_flag && self.has_sent_fin_ack { self.sent_data_amount += 1; self.has_recv_ack_to_fin_ack = true; @@ -270,7 +269,6 @@ impl TCPSession { } TCPSessionState::Closed => { // Send end of stream message - // todo: maybe make this a RST packet? return (None, SessionAction::EndOfStream); } }; @@ -281,6 +279,6 @@ impl TCPSession { let start_ip = EthernetPacket::packet_size() as usize; response.ip.calculate_checksum(&data[start_ip..start_tcp]); response.calculate_checksum(&data[start_tcp..]); - (Some(response), SessionAction::PushUpstream) + (Some(response), response_action) } } From cb62cfab6addc325b42c28ceb14ba5c79cc003ac Mon Sep 17 00:00:00 2001 From: EthanLavi <89608076+EthanLavi@users.noreply.github.com> Date: Sat, 18 Nov 2023 00:28:11 -0500 Subject: [PATCH 16/36] Fix some output stuff --- kernel/src/main.rs | 2 +- kernel/src/network/processing.rs | 1 - kernel/src/task/tcp_echo.rs | 2 +- kernel/src/task/udp_echo.rs | 4 ++-- 4 files changed, 4 insertions(+), 5 deletions(-) diff --git a/kernel/src/main.rs b/kernel/src/main.rs index b0cd786..3273afc 100644 --- a/kernel/src/main.rs +++ b/kernel/src/main.rs @@ -94,7 +94,7 @@ fn kernel_main(boot_info: &'static mut BootInfo) -> ! { executor.spawn(Task::new(do_init_dhcp())); // not entirely async, will finish before others are run executor.wait(); // todo: fix wait executor.spawn(Task::new(keyboard::print_keypresses())); - // executor.spawn(Task::new(udp_echo::udp_echo_server())); + executor.spawn(Task::new(udp_echo::udp_echo_server())); executor.spawn(Task::new(tcp_echo::tcp_echo_server())); executor.run(); } diff --git a/kernel/src/network/processing.rs b/kernel/src/network/processing.rs index f297c87..146dbec 100644 --- a/kernel/src/network/processing.rs +++ b/kernel/src/network/processing.rs @@ -213,7 +213,6 @@ pub async fn init_packet_processing() { // ? I think maybe no, because upstream data should be prioritized -- // ? Otherwise I might not have enough processing speed let tcp_session = rtl_dev_config.tcp_sessions.get_mut(&session_key).unwrap(); - println!("Recv packet! Fin:{} Ack:{} State:{:?}", tcp.get_flags() & TCP_FIN != 0, tcp.get_flags() & TCP_ACK != 0, tcp_session.session_state); // Generate an acknowledgement and determine if the tcp packet has data let ack_pkt = tcp_session.process_recv(&tcp); diff --git a/kernel/src/task/tcp_echo.rs b/kernel/src/task/tcp_echo.rs index 62e6607..611680d 100644 --- a/kernel/src/task/tcp_echo.rs +++ b/kernel/src/task/tcp_echo.rs @@ -4,7 +4,7 @@ use crate::{ use alloc::string::String; pub async fn tcp_echo_server() { - let socket_or_err = Socket::open(SocketType::TCP, 6664, 1).await; + let socket_or_err = Socket::open(SocketType::TCP, 6664, 0).await; match socket_or_err { Ok(mut socket_gen) => { // Allow 10 sockets diff --git a/kernel/src/task/udp_echo.rs b/kernel/src/task/udp_echo.rs index 3134df6..07840fb 100644 --- a/kernel/src/task/udp_echo.rs +++ b/kernel/src/task/udp_echo.rs @@ -32,8 +32,8 @@ pub async fn udp_echo_server() { } } else if let Err(err) = data_or_err { if err == NetworkErrors::Timeout { - println!("[INFO] Socket had a timeout"); - break; + println!("[INFO] UDP-Echo ::> Socket had a timeout reading data!!"); + continue; } println!("[ERR] {:?}", err); break; From 5e104173c49fa5511d1690c58ef52292574f747c Mon Sep 17 00:00:00 2001 From: EthanLavi <89608076+EthanLavi@users.noreply.github.com> Date: Sat, 18 Nov 2023 20:32:39 -0500 Subject: [PATCH 17/36] formatting change to make lines longer (lol we can change it back if so desired) --- build.rs | 8 +- kernel/src/allocator.rs | 8 +- kernel/src/allocator/linked_list.rs | 4 +- kernel/src/framebuffer.rs | 30 ++---- kernel/src/interrupts.rs | 23 ++--- kernel/src/main.rs | 7 +- kernel/src/memory.rs | 5 +- kernel/src/network/arp.rs | 45 +++++++-- kernel/src/network/arp_table.rs | 5 + kernel/src/network/bytefield.rs | 15 ++- kernel/src/network/command_register.rs | 60 +++++++++--- kernel/src/network/constants.rs | 32 ++++--- kernel/src/network/devices.rs | 126 +++++++++++++++---------- kernel/src/network/dhcp.rs | 72 +++++++++++--- kernel/src/network/ethernet.rs | 23 ++++- kernel/src/network/icmp.rs | 1 + kernel/src/network/init.rs | 53 ++++++----- kernel/src/network/ip.rs | 8 +- kernel/src/network/layer.rs | 32 +++---- kernel/src/network/mod.rs | 6 +- kernel/src/network/netsync.rs | 12 +-- kernel/src/network/processing.rs | 64 +++---------- kernel/src/network/raw_array.rs | 6 +- kernel/src/network/raw_socket.rs | 20 ++-- kernel/src/network/rtl8139.rs | 49 +++------- kernel/src/network/socket.rs | 53 +++-------- kernel/src/network/tcp.rs | 8 +- kernel/src/network/tcp_session.rs | 27 ++---- kernel/src/process.rs | 65 +++++-------- kernel/src/serial.rs | 5 +- kernel/src/task/executor.rs | 5 +- kernel/src/task/keyboard.rs | 6 +- kernel/src/task/mod.rs | 4 +- kernel/src/task/tcp_echo.rs | 12 ++- kernel/src/task/timeout.rs | 19 +++- kernel/src/task/udp_echo.rs | 12 ++- rustfmt.toml | 1 + src/main.rs | 12 +-- 38 files changed, 481 insertions(+), 462 deletions(-) create mode 100644 rustfmt.toml diff --git a/build.rs b/build.rs index 347e4df..b380778 100644 --- a/build.rs +++ b/build.rs @@ -9,15 +9,11 @@ fn main() { // create an UEFI disk image (optional) let uefi_path = out_dir.join("uefi.img"); - bootloader::UefiBoot::new(&kernel) - .create_disk_image(&uefi_path) - .unwrap(); + bootloader::UefiBoot::new(&kernel).create_disk_image(&uefi_path).unwrap(); // create a BIOS disk image let bios_path = out_dir.join("bios.img"); - bootloader::BiosBoot::new(&kernel) - .create_disk_image(&bios_path) - .unwrap(); + bootloader::BiosBoot::new(&kernel).create_disk_image(&bios_path).unwrap(); // pass the disk image paths as env variables to the `main.rs` println!("cargo:rustc-env=UEFI_PATH={}", uefi_path.display()); diff --git a/kernel/src/allocator.rs b/kernel/src/allocator.rs index 37b7b1d..b30ebfb 100644 --- a/kernel/src/allocator.rs +++ b/kernel/src/allocator.rs @@ -16,9 +16,7 @@ pub const HEAP_SIZE: usize = 100 * 1024; static ALLOCATOR: Locked = Locked::new(FixedSizeBlockAllocator::new()); use x86_64::{ - structures::paging::{ - mapper::MapToError, FrameAllocator, Mapper, Page, PageTableFlags, Size4KiB, - }, + structures::paging::{mapper::MapToError, FrameAllocator, Mapper, Page, PageTableFlags, Size4KiB}, VirtAddr, }; @@ -35,9 +33,7 @@ pub fn init_heap( }; for page in page_range { - let frame = frame_allocator - .allocate_frame() - .ok_or(MapToError::FrameAllocationFailed)?; + let frame = frame_allocator.allocate_frame().ok_or(MapToError::FrameAllocationFailed)?; let flags = PageTableFlags::PRESENT | PageTableFlags::WRITABLE; unsafe { mapper.map_to(page, frame, flags, frame_allocator)?.flush() }; } diff --git a/kernel/src/allocator/linked_list.rs b/kernel/src/allocator/linked_list.rs index 9a1eeeb..9fd0077 100644 --- a/kernel/src/allocator/linked_list.rs +++ b/kernel/src/allocator/linked_list.rs @@ -28,9 +28,7 @@ pub struct LinkedListAllocator { impl LinkedListAllocator { /// Creates an empty LinkedListAllocator. pub const fn new() -> Self { - Self { - head: ListNode::new(0), - } + Self { head: ListNode::new(0) } } /// Initialize the allocator with the given heap bounds. diff --git a/kernel/src/framebuffer.rs b/kernel/src/framebuffer.rs index 10d2620..70971ac 100644 --- a/kernel/src/framebuffer.rs +++ b/kernel/src/framebuffer.rs @@ -4,9 +4,7 @@ use bootloader_api::{ }; use core::{fmt, ptr}; use font_constants::BACKUP_CHAR; -use noto_sans_mono_bitmap::{ - get_raster, get_raster_width, FontWeight, RasterHeight, RasterizedChar, -}; +use noto_sans_mono_bitmap::{get_raster, get_raster_width, FontWeight, RasterHeight, RasterizedChar}; use spin::Mutex; /// Additional vertical space between lines @@ -38,11 +36,7 @@ mod font_constants { /// Returns the raster of the given char or the raster of [`font_constants::BACKUP_CHAR`]. fn get_char_raster(c: char) -> RasterizedChar { fn get(c: char) -> Option { - get_raster( - c, - font_constants::FONT_WEIGHT, - font_constants::CHAR_RASTER_HEIGHT, - ) + get_raster(c, font_constants::FONT_WEIGHT, font_constants::CHAR_RASTER_HEIGHT) } get(c).unwrap_or_else(|| get(BACKUP_CHAR).expect("Should get raster of backup char.")) } @@ -103,8 +97,7 @@ impl FrameBufferWriter { if new_xpos >= self.width() { self.newline(); } - let new_ypos = - self.y_pos + font_constants::CHAR_RASTER_HEIGHT.val() + BORDER_PADDING; + let new_ypos = self.y_pos + font_constants::CHAR_RASTER_HEIGHT.val() + BORDER_PADDING; if new_ypos >= self.height() { self.clear(); } @@ -139,8 +132,7 @@ impl FrameBufferWriter { }; let bytes_per_pixel = self.info.bytes_per_pixel; let byte_offset = pixel_offset * bytes_per_pixel; - self.framebuffer[byte_offset..(byte_offset + bytes_per_pixel)] - .copy_from_slice(&color[..bytes_per_pixel]); + self.framebuffer[byte_offset..(byte_offset + bytes_per_pixel)].copy_from_slice(&color[..bytes_per_pixel]); let _ = unsafe { ptr::read_volatile(&self.framebuffer[byte_offset]) }; } } @@ -172,17 +164,11 @@ pub fn _print(args: fmt::Arguments) { } pub fn setup(boot_info: &mut BootInfo) { - let framebuffer = core::mem::replace( - &mut boot_info.framebuffer, - bootloader_api::info::Optional::None, - ) - .into_option() - .unwrap(); + let framebuffer = core::mem::replace(&mut boot_info.framebuffer, bootloader_api::info::Optional::None) + .into_option() + .unwrap(); let framebuffer_info = framebuffer.info(); - *WRITER.lock() = Some(FrameBufferWriter::new( - framebuffer.into_buffer(), - framebuffer_info, - )); + *WRITER.lock() = Some(FrameBufferWriter::new(framebuffer.into_buffer(), framebuffer_info)); } #[macro_export] diff --git a/kernel/src/interrupts.rs b/kernel/src/interrupts.rs index f0bdbe9..0dab7c5 100644 --- a/kernel/src/interrupts.rs +++ b/kernel/src/interrupts.rs @@ -1,15 +1,14 @@ +use crate::println; use crate::{gdt, hlt_loop, task::timeout::poll_timeouts}; use lazy_static::lazy_static; -use x86_64::structures::idt::{InterruptDescriptorTable, InterruptStackFrame, PageFaultErrorCode}; -use crate::println; use pic8259::ChainedPics; use spin; +use x86_64::structures::idt::{InterruptDescriptorTable, InterruptStackFrame, PageFaultErrorCode}; pub const PIC_1_OFFSET: u8 = 32; pub const PIC_2_OFFSET: u8 = PIC_1_OFFSET + 8; -pub static PICS: spin::Mutex = - spin::Mutex::new(unsafe { ChainedPics::new(PIC_1_OFFSET, PIC_2_OFFSET) }); +pub static PICS: spin::Mutex = spin::Mutex::new(unsafe { ChainedPics::new(PIC_1_OFFSET, PIC_2_OFFSET) }); #[derive(Debug, Clone, Copy)] #[repr(u8)] @@ -96,10 +95,7 @@ extern "x86-interrupt" fn breakpoint_handler(stack_frame: InterruptStackFrame) { println!("EXCEPTION: BREAKPOINT\n{:#?}", stack_frame); } -extern "x86-interrupt" fn page_fault_handler( - stack_frame: InterruptStackFrame, - error_code: PageFaultErrorCode, -) { +extern "x86-interrupt" fn page_fault_handler(stack_frame: InterruptStackFrame, error_code: PageFaultErrorCode) { use x86_64::registers::control::Cr2; println!("EXCEPTION: PAGE FAULT"); @@ -109,18 +105,14 @@ extern "x86-interrupt" fn page_fault_handler( hlt_loop(); } -extern "x86-interrupt" fn double_fault_handler( - stack_frame: InterruptStackFrame, - _error_code: u64, -) -> ! { +extern "x86-interrupt" fn double_fault_handler(stack_frame: InterruptStackFrame, _error_code: u64) -> ! { panic!("EXCEPTION: DOUBLE FAULT\n{:#?}", stack_frame); } extern "x86-interrupt" fn timer_interrupt_handler(_stack_frame: InterruptStackFrame) { poll_timeouts(); unsafe { - PICS.lock() - .notify_end_of_interrupt(InterruptIndex::Timer.as_u8()); + PICS.lock().notify_end_of_interrupt(InterruptIndex::Timer.as_u8()); } } @@ -132,8 +124,7 @@ extern "x86-interrupt" fn keyboard_interrupt_handler(_stack_frame: InterruptStac crate::task::keyboard::add_scancode(scancode); unsafe { - PICS.lock() - .notify_end_of_interrupt(InterruptIndex::Keyboard.as_u8()); + PICS.lock().notify_end_of_interrupt(InterruptIndex::Keyboard.as_u8()); } } diff --git a/kernel/src/main.rs b/kernel/src/main.rs index 3273afc..dcf8457 100644 --- a/kernel/src/main.rs +++ b/kernel/src/main.rs @@ -17,8 +17,8 @@ use kernel::{ }, println, task::keyboard, - task::{udp_echo, tcp_echo}, task::{executor::Executor, Task}, + task::{tcp_echo, udp_echo}, }; extern crate alloc; @@ -45,7 +45,6 @@ pub static BOOTLOADER_CONFIG: BootloaderConfig = { entry_point!(kernel_main, config = &BOOTLOADER_CONFIG); - async fn do_init_dhcp() { let status_init_dhcp = init_dhcp(4).await; if !status_init_dhcp { @@ -76,9 +75,7 @@ fn kernel_main(boot_info: &'static mut BootInfo) -> ! { if rtl_driver.is_none() { panic!("Cannot find network card"); } - rtl_driver - .unwrap() - .init(&mut frame_allocator, phys_mem_offset) + rtl_driver.unwrap().init(&mut frame_allocator, phys_mem_offset) }; // so that the NET INFO gets released if !status_init { println!("[ERR] Cannot init RTL8139"); diff --git a/kernel/src/memory.rs b/kernel/src/memory.rs index 082eb21..1934ef5 100644 --- a/kernel/src/memory.rs +++ b/kernel/src/memory.rs @@ -17,10 +17,7 @@ impl BootInfoFrameAllocator { /// memory map is valid. The main requirement is that all frames that are marked /// as `USABLE` in it are really unused. pub unsafe fn init(memory_map: &'static MemoryRegions) -> Self { - BootInfoFrameAllocator { - memory_map, - next: 0, - } + BootInfoFrameAllocator { memory_map, next: 0 } } /// Returns an iterator over the usable frames specified in the memory map. diff --git a/kernel/src/network/arp.rs b/kernel/src/network/arp.rs index 57e2396..b298656 100644 --- a/kernel/src/network/arp.rs +++ b/kernel/src/network/arp.rs @@ -5,22 +5,34 @@ use super::{ }; use alloc::vec; use alloc::vec::Vec; + +/// An arp packet, implements Layer (42 bytes) #[derive(Debug)] pub struct ArpPacket { + /// The parent packet pub eth: EthernetPacket, + /// The hardware type => Ethernet is 1 hardware_type: Bytefield16, + /// The protocol => 0x0800 is IPv4 protocol_type: Bytefield16, + /// Should be 6 hardware_address_length: u8, + /// Should be 4 protocol_address_length: u8, + /// 1 for request, 2 for reply operation: Bytefield16, + /// The sender's mac address pub sender_mac: Bytefield48, + /// The sender's IP address pub sender_ip: Bytefield32, + /// The mac address of the receiver (0 if a request, this is the question field) pub recp_mac: Bytefield48, + /// The recepient IP, this is also part of the question if a request pub recp_ip: Bytefield32, } impl ArpPacket { - // Create an empty packet with all 0s + /// Create an empty packet with all 0s pub fn new() -> Self { ArpPacket { eth: EthernetPacket::new(), @@ -36,19 +48,27 @@ impl ArpPacket { } } + /// Generate a ARP packet with + /// - eth_layer: is the ethernet frame associated with the packet + /// - source_ip: is the machine's IP address + /// - recp_ip: is the destination's IP address + /// - is_req: is true if the packet is a request and false if the packet is a response pub fn gen(eth_layer: EthernetPacket, source_ip: u32, recp_ip: u32, is_req: bool) -> Self { + // Extract the recp_mac and sender_mac from the ethernet layer let recp_mac = eth_layer.dest_mac; let sender_mac = eth_layer.src_mac; + // Assert the eth layer matches us assert!(eth_layer.packet_type == EthType::Arp); + // Construct the arp packet ArpPacket { eth: eth_layer, - hardware_type: Bytefield16::new(0x1), // ethernet - protocol_type: Bytefield16::new(0x0800), // ipv4 - hardware_address_length: 6, // ethernet is the value 6 - protocol_address_length: 4, // ethernet is the value 4 + hardware_type: Bytefield16::new(0x1), // ethernet + protocol_type: Bytefield16::new(0x0800), // ipv4 + hardware_address_length: 6, // ethernet is the value 6 + protocol_address_length: 4, // ethernet is the value 4 operation: Bytefield16::new(if is_req { 1 } else { 2 }), // Request is 1, Reply is 2 sender_mac, - sender_ip: Bytefield32::new(source_ip), // what is my ip + sender_ip: Bytefield32::new(source_ip), recp_mac, recp_ip: Bytefield32::new(if is_req { 0 } else { recp_ip }), } @@ -56,13 +76,20 @@ impl ArpPacket { } impl Layer for ArpPacket { + /// The input layer for parse type Input = EthernetPacket; + + /// Parsing an ARP packet requires: + /// - eth_layer: a parsed Ethernet packet + /// - bytevec: the data to parse, with trailing but it starts where the packet must begin fn parse(eth_layer: EthernetPacket, bytevec: &[u8]) -> (Self, usize, LayerType) { let mut packet = ArpPacket::new(); // create an empty packet // Read ethernet packet and 28 bytes let mut i = 0; + // Extract the eth_layer packet.eth = eth_layer; + // Read byte by byte into the struct packet.hardware_type = Bytefield16::read_inc(&bytevec[i..], &mut i); packet.protocol_type = Bytefield16::read_inc(&bytevec[i..], &mut i); packet.hardware_address_length = Bytefield8::read_inc(&bytevec[i..], &mut i).val(); @@ -72,12 +99,14 @@ impl Layer for ArpPacket { packet.sender_ip = Bytefield32::read_inc(&bytevec[i..], &mut i); packet.recp_mac = Bytefield48::read_inc(&bytevec[i..], &mut i); packet.recp_ip = Bytefield32::read_inc(&bytevec[i..], &mut i); - assert!(i == 28); // Arp packet should be 28 bytes + // Return the packet, the amount of data consumed, and the next layer type (end of parse) (packet, i, LayerType::END) } + /// Serialize the packet into a vector of bytes, ready to send over the network fn serialize(&self) -> Vec { + // Create a vector and serialize it let mut res = vec![]; res.extend(self.eth.serialize()); res.extend(self.hardware_type.data); @@ -92,7 +121,9 @@ impl Layer for ArpPacket { res } + /// The amount of data that belongs to the packet-type fn packet_size() -> u16 { + // 28 bytes 28 } } diff --git a/kernel/src/network/arp_table.rs b/kernel/src/network/arp_table.rs index 02b0835..2771eab 100644 --- a/kernel/src/network/arp_table.rs +++ b/kernel/src/network/arp_table.rs @@ -1,5 +1,10 @@ +/// An entry in the ARP table +/// todo: This prob shouldn't get its own file pub struct ArpEntry { + /// The mac address of the ARP entry pub mac: u64, + /// The IP address of the ARP entry pub ip: u32, + /// How many timer interrupts before this entry expires pub expires: u16, } diff --git a/kernel/src/network/bytefield.rs b/kernel/src/network/bytefield.rs index a1fdb65..30ab0e1 100644 --- a/kernel/src/network/bytefield.rs +++ b/kernel/src/network/bytefield.rs @@ -1,23 +1,28 @@ use core::ops::{Index, IndexMut}; -// N.B.: Bytefields will swap the endianness of the values when created +// N.B.: Bytefields will swap the endianness of the values when created // --> therefore serializing will create network byte order (big endian) Bytefields // --> AND deserializing will create host byte order (small endian) Bytefields // (That's why val() doesn't swap the byte order!) // todo: refactor the api to track the state of the byte order? (would this work?) +/// A structure to represent data as an array of bytes +/// Will store in both host and network byte-order. It is dependent on it's construction #[derive(Debug, Clone, Copy)] pub struct Bytefield { pub data: [u8; N], } impl Bytefield { + /// Return the bytefield with swapped endianness pub fn swapped_endianness(self) -> Self { let mut data = self.data; data.reverse(); Self { data } } + /// Will parse the first N bytes from the bytevec and swap the order (network-->host) + /// AND It will increment i by N. pub fn read_inc(bytevec: &[u8], i: &mut usize) -> Self { let mut data = [0_u8; N]; for i in 0..N { @@ -27,6 +32,7 @@ impl Bytefield { Self { data } } + /// Get the number of bytes in the type pub const fn size() -> usize { N } @@ -34,12 +40,14 @@ impl Bytefield { impl Index for Bytefield { type Output = u8; + /// Read from the bytefield's inner array fn index(&self, i: usize) -> &u8 { &self.data[i] } } impl IndexMut for Bytefield { + /// Get a mutable entry to the bytefield's inner array fn index_mut(&mut self, i: usize) -> &mut u8 { &mut self.data[i] } @@ -50,6 +58,7 @@ macro_rules! bytefield_int { pub type $t = Bytefield<$size>; impl $t { + /// Will construct a bytefield with value: val in swapped order (host-->network) pub fn new(val: $int) -> Self { let mut data = [0; $size]; for i in 0..$size { @@ -58,7 +67,8 @@ macro_rules! bytefield_int { $t { data } } - // get the data in the natural endianness + /// Get the data in the stored endianness + /// (if host, will return in host order) pub fn val(&self) -> $int { let mut res = 0; for i in 0..$size { @@ -70,6 +80,7 @@ macro_rules! bytefield_int { }; } +// Define different bytefield types bytefield_int!(Bytefield8, u8, 1); bytefield_int!(Bytefield16, u16, 2); bytefield_int!(Bytefield32, u32, 4); diff --git a/kernel/src/network/command_register.rs b/kernel/src/network/command_register.rs index 669ab81..6de30dc 100644 --- a/kernel/src/network/command_register.rs +++ b/kernel/src/network/command_register.rs @@ -1,22 +1,30 @@ +/// CommandRegister is an object to represent the command register for PCI devices +/// See https://wiki.osdev.org/PCI#Command_Register #[derive(Debug, Clone)] pub struct CommandRegister { + /// The value of the command register cr: u16, } impl CommandRegister { + /// Construct a new instance with value, cr pub fn new(cr: u16) -> Self { CommandRegister { cr } } - // basic getter for internal data + /// Getter for internal data (not pub on "cr" because shouldn't be allowed to modify "cr" willy-nilly) pub fn data(&self) -> u16 { self.cr } - // 0th bit + /// Get the io_space bit, the 0th bit + /// If on, the device can respond to I/O space accesses pub fn get_io_space_bit(&self) -> bool { (self.cr & 0x1) != 0 } + /// Set the io_space bit, the 0th bit + /// - is_on: if true, will turn on. Else turn off + /// If on, the device can respond to I/O space accesses pub fn set_io_space_bit(&mut self, is_on: bool) { match is_on { true => self.cr |= 0x1, @@ -24,10 +32,14 @@ impl CommandRegister { } } - // 1st bit + /// Get the memory_space bit, the 1st bit + /// If on, the device can respond to memory space accesses pub fn get_memory_space_bit(&self) -> bool { (self.cr & 0x2) != 0 } + /// Set the memory_space bit, the 1st bit + /// - is_on: if true, will turn on. Else turn off + /// If on, the device can respond to memory space accesses pub fn set_memory_space_bit(&mut self, is_on: bool) { match is_on { true => self.cr |= 0x2, @@ -35,10 +47,14 @@ impl CommandRegister { } } - // 2nd bit + /// Get the bus master bit, the 2nd bit + /// If on, the device can behave as a bus master; otherwise the device cannot generate PCI accesses pub fn get_bus_master_bit(&self) -> bool { (self.cr & 0x4) != 0 } + /// Set the bus master bit, the 2nd bit + /// - is_on: if true, will turn on. Else turn off + /// If on, the device can behave as a bus master; otherwise the device cannot generate PCI accesses pub fn set_bus_master_bit(&mut self, is_on: bool) { match is_on { true => self.cr |= 0x4, @@ -46,36 +62,52 @@ impl CommandRegister { } } - // 3rd bit + /// Get the special_cycles bit, the 3rd bit + /// If on, the device can monitor special cycle operations pub fn get_special_cycles_bit(&self) -> bool { (self.cr & 0x8) != 0 } - // 4th bit + /// Get the memory_write_invalidate_enable bit, the 4th bit + /// If on, the device can generate the memory write and invalidate command + /// (to invalidate cached data of a snooped memory location) pub fn get_memory_write_invalidate_enable_bit(&self) -> bool { (self.cr & 0x10) != 0 } - // 5th bit + /// Get the vga palette snoop bit, 5th bit + /// If on, the device does not respond to palette register writes and will snoop the data + /// (This has to do with setting colors of the VGA buffer?) pub fn get_vga_palette_snoop_bit(&self) -> bool { (self.cr & 0x20) != 0 } - // 6th bit + /// Get the parity err response bit, 6th bit + /// If on, the device will take normal action when parity error is detected + /// Otherwise, if detected, the device will set bit 15 of the status register and will continue as normal pub fn get_parity_err_res_bit(&self) -> bool { (self.cr & 0x40) != 0 } + /// Set the parity err response bit, 6th bit + /// - is_on: if true, will turn on. Else turn off + /// If on, the device will take normal action when parity error is detected + /// Otherwise, if detected, the device will set bit 15 of the status register and will continue as normal pub fn set_parity_err_res_bit(&mut self, is_on: bool) { match is_on { true => self.cr |= 0x40, false => self.cr &= !0x40, } } + // Bit 7 is reserved and RO - // 8th bit + /// Get the serr enable bit, 8th bit + /// If on, the SERR# driver is enabled pub fn get_serr_enable_bit(&self) -> bool { (self.cr & 0x100) != 0 } + /// Set the serr enable bit, 8th bit + /// - is_on: if true, will turn on. Else turn off + /// If on, the SERR# driver is enabled pub fn set_serr_enable_bit(&mut self, is_on: bool) { match is_on { true => self.cr |= 0x100, @@ -83,19 +115,25 @@ impl CommandRegister { } } - // 9th bit + /// Get the fast back-back enable bit, 9th bit + /// If on, a device is allowed to generate fast back-2-back transactions. pub fn get_fast_back_to_back_enable_bit(&self) -> bool { (self.cr & 0x200) != 0 } - // 10th bit + /// Get the interrupt disabled bit, 10th bit + /// If on, the interrupts from this device are disabled pub fn get_interrupt_disable_bit(&self) -> bool { (self.cr & 0x400) != 0 } + /// Set interrupt disabled bit, 10th bit + /// - is_on: if true, will turn on. Else turn off + /// If on, the interrupts from this device are disabled pub fn set_interrupt_disable_bit(&mut self, is_on: bool) { match is_on { true => self.cr |= 0x400, false => self.cr &= !0x400, } } + // Bits 11-15 are reserved } diff --git a/kernel/src/network/constants.rs b/kernel/src/network/constants.rs index fa956b9..ce5e374 100644 --- a/kernel/src/network/constants.rs +++ b/kernel/src/network/constants.rs @@ -1,10 +1,14 @@ +// PCI things +pub const PCI_CONFIG_ADDRESS: u16 = 0xCF8; // specifies the configuration address that is required to be accesses +pub const PCI_CONFIG_DATA: u16 = 0xCFC; // will actually generate the configuration access and will transfer the data to or from the CONFIG_DATA register + // Broadcast constants -pub const BROADCAST_ADDR: u32 = 0xFFFFFFFF; -pub const BROADCAST_MAC: u64 = 0xFFFFFFFFFFFF; +pub const BROADCAST_ADDR: u32 = 0xFFFFFFFF; // broadcast IP address +pub const BROADCAST_MAC: u64 = 0xFFFFFFFFFFFF; // broadcast MAC address // Common port numbers -pub const DHCP_CLIENT_PORT: u64 = 68; -pub const DHCP_SERVER_PORT: u64 = 67; +pub const DHCP_CLIENT_PORT: u64 = 68; // client port for dhcp requests +pub const DHCP_SERVER_PORT: u64 = 67; // server port for dhcp requests pub const ARP_PORT: u64 = u16::MAX as u64 + 2; // we are using ports above u16 to allow for extended our open "ports" map // RTL device specific info @@ -12,9 +16,9 @@ pub const RTL_VEND: u32 = 0x10EC; // Vendor ID for the rtl8139 pub const RTL_DEV: u32 = 0x8139; // Device ID for the rtl8139 pub const TOK: u16 = 1 << 2; // TOK bit pub const ROK: u16 = 1 << 0; // ROK bit -pub const TRANSMIT_REG: [u32; 4] = [0x20, 0x24, 0x28, 0x2C]; -pub const TRANSMIT_CMD: [u32; 4] = [0x10, 0x14, 0x18, 0x1C]; -pub const INTERRUPT_MASK: u16 = 0x01 | 0x04 | 0x10 | 0x08 | 0x02; +pub const TRANSMIT_REG: [u32; 4] = [0x20, 0x24, 0x28, 0x2C]; // 4 transmit registers +pub const TRANSMIT_CMD: [u32; 4] = [0x10, 0x14, 0x18, 0x1C]; // 4 cmd registers +pub const INTERRUPT_MASK: u16 = 0x01 | 0x04 | 0x10 | 0x08 | 0x02; // interrupt mask pub const RX_BUFFER_SIZE: u16 = 8192; // how big the buffer is pub const CR_RST: u16 = 0x10; // Reset, set to 1 to invoke S/W reset, held to 1 while resetting pub const CR_RE: u8 = 0x08; // Reciever Enable, enables receiving @@ -22,12 +26,12 @@ pub const CR_TE: u8 = 0x04; // Transmitter Enable, enables transmitting pub const CR_BUFE: u8 = 0x01; // Rx buffer is empty pub const CR: u32 = 0x37; // command register (1byte) pub const CAPR: u32 = 0x38; // current address of packet read (2byte, C mode, initial value 0xFFF0) -pub const RX_READ_PTR_MASK: u16 = !0x3; // +pub const RX_READ_PTR_MASK: u16 = !0x3; // Used to align to 8 bytes in RTL8139 driver // TCP Constants -pub const TCP_FIN: u8 = 0x1; -pub const TCP_SYN: u8 = 0x2; -pub const TCP_RST: u8 = 0x4; -pub const TCP_PSH: u8 = 0x8; -pub const TCP_ACK: u8 = 0x10; -pub const TCP_URG: u8 = 0x20; \ No newline at end of file +pub const TCP_FIN: u8 = 0x1; // TCP FIN flag (gracefully closing connection) +pub const TCP_SYN: u8 = 0x2; // TCP SYN flag (synchronizing to open connection) +pub const TCP_RST: u8 = 0x4; // TCP RST flag (forcefully terminate connection) +pub const TCP_PSH: u8 = 0x8; // TCP PSH flag (push data to application layer) +pub const TCP_ACK: u8 = 0x10; // TCP ACK flag (acknowledge something -- multi-use) +pub const TCP_URG: u8 = 0x20; // TCP URG flag (urgent data) diff --git a/kernel/src/network/devices.rs b/kernel/src/network/devices.rs index b0fc5dd..f66a4d6 100644 --- a/kernel/src/network/devices.rs +++ b/kernel/src/network/devices.rs @@ -1,23 +1,12 @@ use alloc::vec::Vec; use x86_64::instructions::port::Port; -use super::command_register::CommandRegister; - -const CONFIG_ADDRESS: u16 = 0xCF8; -const CONFIG_DATA: u16 = 0xCFC; - -#[derive(Clone)] -pub struct Device { - pub bus: u8, - pub slot: u8, - pub vendor_id: u16, - pub device_id: u16, - pub class_code: PCIClassCodes, - pub sub_class: u8, - pub io_base: Option, - pub irq: Option, -} +use super::{ + command_register::CommandRegister, + constants::{PCI_CONFIG_ADDRESS, PCI_CONFIG_DATA}, +}; +/// Types of devices for PCI #[derive(Debug, PartialEq, Eq, Clone)] pub enum PCIClassCodes { Unclassified, @@ -46,6 +35,7 @@ pub enum PCIClassCodes { } impl PCIClassCodes { + /// Match a PCI class code (class_code) with a number pub fn from(class_code: u8) -> Self { match class_code { 0x0 => Self::Unclassified, @@ -75,39 +65,46 @@ impl PCIClassCodes { } } -// Write into the config address +/// Write into the config address, which helps PCI determine which device's config space we are modifying/reading +/// bus + slot + func + offset identifies a unique device fn create_confg_address(bus: u8, slot: u8, func: u8, offset: u8) { + // Cast the u32 let lbus = bus as u32; let lslot = slot as u32; let lfunc = func as u32; - let address = - (lbus << 16) | (lslot << 11) | (lfunc << 8) | ((offset as u32) & 0xFC) | 0x80000000_u32; - let mut port = Port::::new(CONFIG_ADDRESS); - // Write the address + // Create the address of the PCI device information + let address = (lbus << 16) | (lslot << 11) | (lfunc << 8) | ((offset as u32) & 0xFC) | 0x80000000_u32; + // Set the PCI_CONFIG_ADDRESS to point to that device + let mut port = Port::::new(PCI_CONFIG_ADDRESS); unsafe { port.write(address) }; } -// Query PCI for 32 bits of data of a device defined by (bus, slot) -// Will offset into the PCI Configuration Header by offset -// Func is used for multi-function devices... (don't think that will be used with network card) +/// Query PCI for 32 bits of data of a device defined by (bus + slot) +/// Will offset into the PCI Configuration Header by offset +/// Func is used for multi-function devices... (don't think that will be used with network card) fn pci_config_read_dword(bus: u8, slot: u8, func: u8, offset: u8) -> u32 { + // Load the correct bus/slot id for our intended device create_confg_address(bus, slot, func, offset); - let mut port = Port::::new(CONFIG_DATA); - // Read the data + // Read the data from the PCI_CONFIG_DATA register to get the information from the register + let mut port = Port::::new(PCI_CONFIG_DATA); let data: u32 = unsafe { port.read() }; data } /// Query PCI to write 16 bits of data of a device defined by (bus, slot) /// Will offset into the PCI configuration header by offset -// Func is used for multi-function devices... (don't think that will be used with network card) +/// Func is used for multi-function devices... (don't think that will be used with network card) fn pci_config_write_word(bus: u8, slot: u8, func: u8, offset: u8, word: u16) { + // Load the correct bus/slot id for our intended device create_confg_address(bus, slot, func, offset); - let mut port = Port::::new(CONFIG_DATA); - // Read the data + // Write the data to the PCI_CONFIG_DATA register to set the information in the device we identified above + let mut port = Port::::new(PCI_CONFIG_DATA); + // Read the data first (only modifying the first half of 32 bytes) let data: u32 = unsafe { port.read() }; + // Modify the red data to be what is should now be let new_data = (data & 0xFFFF0000) | word as u32; + // Commit the modification unsafe { port.write(new_data); } @@ -116,25 +113,25 @@ fn pci_config_write_word(bus: u8, slot: u8, func: u8, offset: u8, word: u16) { /// Check if a device exists at (bus, slot) /// If it does, it will return the vendor ID fn pci_check_vendor(bus: u8, slot: u8) -> Option { - /* Try and read the first configuration register. Since there are no - * vendors that == 0xFFFF, it must be a non-existent device. */ - + // Read the bus and slot to get the vendor ID let vendor = (pci_config_read_dword(bus, slot, 0, 0) & 0xFFFF) as u16; + // If vendor is 0xFFFF, we have a non-existent connection on PCI if vendor != 0xFFFF { return Some(vendor); } None } -// Assumes a device at (bus, slot) -// Will extract the class code from the configuration space +/// Assumes a device exists at (bus, slot) +/// Will try extract the class code from the configuration space fn pci_get_device_id(bus: u8, slot: u8) -> u16 { + // Read the device id from the second half of the PCI configuration space at offset 0 let device_id = (pci_config_read_dword(bus, slot, 0, 0) >> 16) & 0xFFFF; device_id as u16 } -// Assumes a device at (bus, slot) -// Will extract the class code (class and subclass) from the configuration space +/// Assumes a device exists at (bus, slot) +/// Will extract the class code (class and subclass) from the configuration space fn pci_get_class_code(bus: u8, slot: u8) -> (PCIClassCodes, u8) { let code = pci_config_read_dword(bus, slot, 0, 0x8); // let _revision_id = (code & 0xFF) as u8; @@ -144,36 +141,42 @@ fn pci_get_class_code(bus: u8, slot: u8) -> (PCIClassCodes, u8) { (PCIClassCodes::from(class), subclass) } -// Read the interrupt line of the PCI Configuration address space -// "For the x86 architecture this register corresponds to the PIC IRQ numbers 0-15" and a value of 0xFF defines no connection +/// Read the interrupt line of the PCI Configuration address space +/// "For the x86 architecture this register corresponds to the PIC IRQ numbers 0-15" and a value of 0xFF defines no connection fn pci_get_irq(bus: u8, slot: u8) -> Option { + // Read the IRQ number and return it, if it exists (not 0xFF) let irq = (pci_config_read_dword(bus, slot, 0, 0x3C) & 0xFF) as u8; - if irq == 0xFF { - None - } else { - Some(irq) + if irq != 0xFF { + return Some(irq); } + None } +/// Assumes a device exists at (bus, slot) +/// Read the command register from the PCI configuration space fn pci_get_cmd_reg(bus: u8, slot: u8) -> CommandRegister { let cr = pci_config_read_dword(bus, slot, 0, 0x4); CommandRegister::new((cr & 0xFFFF) as u16) } -// Set the command register +/// Assumes a device exists at (bus, slot) +/// Will set the command register to be it's new value fn pci_set_cmd_reg(bus: u8, slot: u8, cr: CommandRegister) { // Write the register with cr pci_config_write_word(bus, slot, 0, 0x4, cr.data()); } -// Get io base of header type 0x0 +/// Assumes a device exists at (bus, slot) +/// Will get io base of header type 0x0 fn pci_get_io_base(bus: u8, slot: u8) -> Option { let mut cr = pci_get_cmd_reg(bus, slot); let org = cr.clone(); // Must disable IO and memory space decode before accessing bar cr.set_memory_space_bit(false); cr.set_io_space_bit(false); + // todo: why doesn't this work? --> pci_set_cmd_reg(bus, slot, cr); for i in 0..5 { + // Read the BAR let bar = pci_config_read_dword(bus, slot, 0, 0x10 + (i * 0x4)); if (bar & 0x1) == 0x1 { // We have an IO address @@ -186,24 +189,47 @@ fn pci_get_io_base(bus: u8, slot: u8) -> Option { None } +/// A struct to represent a PCI-discovered device +#[derive(Clone)] +pub struct Device { + /// The bus number + pub bus: u8, + /// The slot number + pub slot: u8, + /// The vendor id + pub vendor_id: u16, + /// The device id + pub device_id: u16, + /// The type of device + pub class_code: PCIClassCodes, + /// Subclass + pub sub_class: u8, + /// IO_Base number + pub io_base: Option, + /// IRQ number + pub irq: Option, +} + impl Device { - // Read the command register of the device + /// Read the command register of the device pub fn read_command_register(&self) -> CommandRegister { pci_get_cmd_reg(self.bus, self.slot) } - // Write to the command register of the device + /// Write to the command register of the device pub fn write_command_register(&self, new_val: CommandRegister) { pci_set_cmd_reg(self.bus, self.slot, new_val); } } -// Will return X devices -// TODO: multiprocessing safety? +/// Will return all found devices pub fn scan_devices() -> Vec { + // Create an empty vector of u8 pairs let mut device_bus_slots: Vec<(u8, u8)> = Vec::new(); for bus in 0..255 { for slot in 0..31 { + // Iterate through all possible buses and slots and determine if a device exists. + // If so append the (bus, slot) pair match pci_check_vendor(bus, slot) { Some(_) => { device_bus_slots.push((bus, slot)); @@ -214,7 +240,9 @@ pub fn scan_devices() -> Vec { } let mut results: Vec = Vec::new(); + // Iterate through all the bus_slot pairs for bus_slot in device_bus_slots.iter() { + // Extract the necessary information to create the device object let bus = bus_slot.0; let slot = bus_slot.1; let vendor_id = pci_check_vendor(bus, slot).unwrap(); @@ -233,6 +261,6 @@ pub fn scan_devices() -> Vec { io_base, }); } - + // Return the results of the device scan results } diff --git a/kernel/src/network/dhcp.rs b/kernel/src/network/dhcp.rs index b06f0e8..69b924c 100644 --- a/kernel/src/network/dhcp.rs +++ b/kernel/src/network/dhcp.rs @@ -7,41 +7,65 @@ use super::{ udp::UDPPacket, }; +/// A wrapper for thread-safe id generation struct WrappedU32 { data: u32, } impl WrappedU32 { + /// Get the value pub fn get(&self) -> u32 { self.data } + /// Set the value pub fn set(&mut self, data: u32) { self.data = data; } } +// todo: Extract this generator (and refactor to a lockless approach) to the crypto folder +/// Generator for random IDs static mut ID_GEN: spin::Mutex = spin::Mutex::new(WrappedU32 { data: 0 }); + +/// A DHCP packet, implements Layer (usually 300 bytes or more) #[derive(Debug)] pub struct DHCPPacket { - pub udp: UDPPacket, // public for checksumming + /// The parent UDP packet + pub udp: UDPPacket, + /// The DHCP operation, 1 is request and 2 is reply op_code: u8, + /// 1 is ethernet. More information on hardware types can be found here http://www.tcpipguide.com/free/t_DHCPMessageFormat.htm hardware_type: u8, + /// For ethernet, the value is 6. hardware_address_length: u8, + /// Set to 0 by a client before transmitting a request and used by relay agents to control the forwarding of BOOTP and/or DHCP messages. hops: u8, + /// A 32-bit identification field generated by the client, to allow it to match up the request with replies received from DHCP servers. transaction_identifier: Bytefield32, + /// In BOOTP, not used. For DHCP, it is defined as the # of seconds elapsed since a client began an attempt to get an IP seconds: Bytefield16, + /// Contains one field - broadcast (if doesn't know who the DHCP server is) flags: Bytefield16, + /// The client's IP (can be 0, if unknown) pub client_ip: Bytefield32, + /// The ip the server is assigning to you pub my_ip: Bytefield32, + /// The DHCP server IP pub server_ip: Bytefield32, + /// The gateway IP pub gateway_ip: Bytefield32, + /// The client's hardware address. Mac address typically but its 16 bytes client_hardware_address: Bytefield128, + /// Server name sname: [u8; 64], + /// Boot Filename file: [u8; 128], + /// DHCP options -- fixed length in BOOTP options: [u8; 64], // variable length } impl DHCPPacket { + /// Generate a new DHCP packet with all 0s pub fn new() -> Self { DHCPPacket { udp: UDPPacket::new(), @@ -64,26 +88,33 @@ impl DHCPPacket { } } - pub fn gen(udp_packet: UDPPacket, ip_address: Option, mac_address: u64) -> Self { + /// Generate a DHCP packet with + /// - udp_layer: is the udp data associated with the packet + /// - ip_address: is the machine's IP address, or None if unknown + /// - mac_address: is the machine's MAC address + pub fn gen(udp_layer: UDPPacket, ip_address: Option, mac_address: u64) -> Self { + // Generate a unique identifier for the client let identification = unsafe { let mut id_gen = ID_GEN.lock(); let id_gen_old = id_gen.get(); id_gen.set((id_gen_old + 1) % 0xFFFF); Bytefield32::new(id_gen.get()) }; + // Get the mac address and place it into the client_hardware_address field let mac = Bytefield48::new(mac_address); let mut client_hardware_address = Bytefield128::new(0); for i in 0..6 { client_hardware_address[i] = mac[i]; } + // Generate the DHCP packet let mut dhcp = DHCPPacket { - udp: udp_packet, - op_code: 1, // 1 for is request - hardware_type: 1, // ethernet is 1 - hardware_address_length: 6, // corresponds with hardware address length - hops: 0, // must be 0 from client - transaction_identifier: identification, // random ID - seconds: Bytefield16::new(0), // doesn't matter + udp: udp_layer, + op_code: 1, // 1 for is request + hardware_type: 1, // ethernet is 1 + hardware_address_length: 6, // corresponds with hardware address length + hops: 0, // must be 0 from client + transaction_identifier: identification, // random ID + seconds: Bytefield16::new(0), // doesn't matter flags: Bytefield16::new(if ip_address.is_none() { 0x1 } else { 0x0 }), // 1 indicates that I don't know my IP address client_ip: Bytefield32::new(ip_address.unwrap_or(0)), // 0 since my_ip: Bytefield32::new(0), // 4 bytes @@ -92,12 +123,12 @@ impl DHCPPacket { client_hardware_address, // 16 bytes sname: [0; 64], // 64 bytes file: [0; 128], // 128 bytes - options: [0; 64], // todo: 64 bytes (can be more) - // 300 bytes total + options: [0; 64], // todo: 64 bytes (can be more) + // 300 bytes total }; + // Calculate the checksums let data = dhcp.serialize(); - let start_udp = - data.len() - (DHCPPacket::packet_size() as usize + UDPPacket::packet_size() as usize); + let start_udp = data.len() - (DHCPPacket::packet_size() as usize + UDPPacket::packet_size() as usize); let start_ip = start_udp - (IPPacket::packet_size() as usize); dhcp.udp.ip.calculate_checksum(&data[start_ip..start_udp]); dhcp.udp.calculate_checksum(&data[start_udp..]); @@ -106,14 +137,22 @@ impl DHCPPacket { } impl Layer for DHCPPacket { + /// The input layer for parse type Input = UDPPacket; + + /// Parsing a DHCP packet requires: + /// - udp_layer: a parsed UDP packet + /// - bytevec: the data to parse, with trailing but it starts where the packet must begin fn parse(udp_layer: UDPPacket, bytevec: &[u8]) -> (Self, usize, LayerType) where Self: Sized, { + // Create a packet let mut packet = DHCPPacket::new(); // create an empty packet let mut i = 0; + // Fill in udp layer packet.udp = udp_layer; + // Read in with bytefields packet.op_code = Bytefield8::read_inc(&bytevec[i..], &mut i).val(); packet.hardware_type = Bytefield8::read_inc(&bytevec[i..], &mut i).val(); packet.hardware_address_length = Bytefield8::read_inc(&bytevec[i..], &mut i).val(); @@ -130,13 +169,18 @@ impl Layer for DHCPPacket { i += 128; // file i += 64; // options // Ignoring those parts of the DHCP packet. + // Get data left to parse (might be variable length DHCP packet) let left_to_parse = packet.udp.length.val() - 308; + // Increment I and assert its at least 300 i += left_to_parse as usize; assert!(i >= 300); // 300 bytes + // Return the packet, the amount of data consumed, and the next layer type (end of parse) (packet, i, LayerType::END) } + /// Serialize the packet into a vector of bytes, ready to send over the network fn serialize(&self) -> alloc::vec::Vec { + // Create a vector and serialize it let mut res = vec![]; res.extend(self.udp.serialize()); res.push(self.op_code); @@ -154,10 +198,12 @@ impl Layer for DHCPPacket { res.extend(self.sname); res.extend(self.file); res.extend(self.options); + // Assert the length is as expected assert!(res.len() == (300 + self.udp.serialize().len())); res } + /// The amount of data that belongs to the packet-type fn packet_size() -> u16 { 300 } diff --git a/kernel/src/network/ethernet.rs b/kernel/src/network/ethernet.rs index 32a530c..b398877 100644 --- a/kernel/src/network/ethernet.rs +++ b/kernel/src/network/ethernet.rs @@ -5,6 +5,8 @@ use super::{ use alloc::vec; use alloc::vec::Vec; +/// Ethernet type for the packet +/// - Mainly used for ARP or IPv4 #[derive(Debug, Copy, Clone, PartialEq, Eq)] #[repr(u16)] pub enum EthType { @@ -15,6 +17,7 @@ pub enum EthType { } impl EthType { + /// Generate an EthType from a value pub fn from(packet_type: u16) -> Self { match packet_type { 0x0806 => Self::Arp, @@ -23,12 +26,13 @@ impl EthType { } } + /// Convert the enum to a bytefield pub fn as_bytefield(&self) -> Bytefield16 { Bytefield16::new(*self as u16) } } -// Total size is 14 bytes +/// An ethernet packet, implements Layer (14 bytes) #[derive(Debug, Clone)] pub struct EthernetPacket { pub dest_mac: Bytefield48, // u48 @@ -37,6 +41,7 @@ pub struct EthernetPacket { } impl EthernetPacket { + /// Create an empty packet with all 0s pub fn new() -> Self { EthernetPacket { dest_mac: Bytefield48::new(0), @@ -45,6 +50,10 @@ impl EthernetPacket { } } + /// Generate a Ethernet packet with + /// - destination_mac: the destination mac address + /// - source_mac: the source mac address + /// - packet_type: the class of packet to send pub fn gen(destination_mac: u64, source_mac: u64, packet_type: EthType) -> Self { EthernetPacket { dest_mac: Bytefield48::new(destination_mac), @@ -55,7 +64,12 @@ impl EthernetPacket { } impl Layer for EthernetPacket { + /// The input layer for parse type Input = EmptyLayer; + + /// Parsing an Ethernet packet requires: + /// - _empty_layer: A unused placeholder to fit the Layer paradigm + /// - bytevec: the data to parse, with trailing but it starts where the packet must begin fn parse(_empty_layer: EmptyLayer, bytevec: &[u8]) -> (Self, usize, LayerType) where Self: Sized, @@ -66,17 +80,21 @@ impl Layer for EthernetPacket { packet.dest_mac = Bytefield48::read_inc(&bytevec[i..], &mut i); packet.src_mac = Bytefield48::read_inc(&bytevec[i..], &mut i); packet.packet_type = EthType::from(Bytefield16::read_inc(&bytevec[i..], &mut i).val()); - assert!(i == 14); // 14 bytes + assert!(i == 14); // assert 14 bytes + // Match on the packet type to make sure we don't have an error let layer_type = match &packet.packet_type { EthType::Arp => LayerType::ARP, EthType::IPv4 => LayerType::IP, EthType::RoCE => LayerType::ERR, EthType::Unknown => LayerType::ERR, }; + // Return the packet, the amount of data consumed, and the next layer type (based on EthType) (packet, i, layer_type) } + /// Serialize the packet into a vector of bytes, ready to send over the network fn serialize(&self) -> Vec { + // Create a vector and serialize it let mut res = vec![]; res.extend(self.dest_mac.data); res.extend(self.src_mac.data); @@ -85,6 +103,7 @@ impl Layer for EthernetPacket { res } + /// The amount of data that belongs to the packet-type fn packet_size() -> u16 { 14 } diff --git a/kernel/src/network/icmp.rs b/kernel/src/network/icmp.rs index e69de29..e973b82 100644 --- a/kernel/src/network/icmp.rs +++ b/kernel/src/network/icmp.rs @@ -0,0 +1 @@ +// TODO: \ No newline at end of file diff --git a/kernel/src/network/init.rs b/kernel/src/network/init.rs index 9f15275..b6d641d 100644 --- a/kernel/src/network/init.rs +++ b/kernel/src/network/init.rs @@ -6,88 +6,95 @@ use crate::network::dhcp::DHCPPacket; use crate::network::ethernet::{EthType, EthernetPacket}; use crate::network::ip::{IPPacket, Protocol}; use crate::network::layer::{Layer, LayerType}; -use crate::network::rtl8139::{disable_network_interrupts, enable_network_interrupts, NET_INFO}; use crate::network::raw_socket::RawSocket; +use crate::network::rtl8139::{disable_network_interrupts, enable_network_interrupts, NET_INFO}; use crate::network::udp::UDPPacket; -use crate::task::Task; use crate::task::executor::Executor; +use crate::task::Task; use crate::{network::constants::DHCP_SERVER_PORT, println}; +/// Initialize all the network related stuff pub fn init() { // todo bundle the init phases } +/// Initialize dhcp by finding my IP address pub async fn init_dhcp(wait_timeout: u8) -> bool { // open a socket (shouldn't fail since we are in the init process) let mut socket = RawSocket::new(DHCP_CLIENT_PORT, 1).unwrap(); + // Get the network driver object disable_network_interrupts(); let rtl_dev_guard = NET_INFO.lock(); let rtl_dev_info = rtl_dev_guard.get_ref().unwrap(); - // send dhcp initial request - let eth = EthernetPacket::gen( - BROADCAST_MAC, - rtl_dev_info.mac_address.unwrap(), - EthType::IPv4, - ); + // create dhcp initial request + let eth = EthernetPacket::gen(BROADCAST_MAC, rtl_dev_info.mac_address.unwrap(), EthType::IPv4); let ip_size = DHCPPacket::packet_size() + UDPPacket::packet_size(); let ip = IPPacket::gen(eth, ip_size, Protocol::UDP, 0x0, BROADCAST_ADDR); - let udp = UDPPacket::gen( - ip, - DHCP_CLIENT_PORT as u16, - DHCP_SERVER_PORT as u16, - DHCPPacket::packet_size(), - ); + let udp = UDPPacket::gen(ip, DHCP_CLIENT_PORT as u16, DHCP_SERVER_PORT as u16, DHCPPacket::packet_size()); let dhcp = DHCPPacket::gen(udp, None, rtl_dev_info.mac_address.unwrap()); let packet_data = dhcp.serialize(); - rtl_dev_info.send_packet(&packet_data); // send first packet + // Send the first BOOTP packet + rtl_dev_info.send_packet(&packet_data); + // Release the driver object drop(rtl_dev_guard); enable_network_interrupts(); - // get response - let mut timeout = 0; + // Wait for response + let mut retries = 0; let pkt_data; loop { + // First poll on the raw socket if let Some(dhcp_res) = socket.next().await { if dhcp_res.is_err() { // If we got a socket error, we must return false return false; } + // Unwrap the data and break pkt_data = dhcp_res.unwrap(); break; } else { + // Resend the packet after a timeout or other error disable_network_interrupts(); - { let rtl_dev_guard = NET_INFO.lock(); let rtl_dev_info = rtl_dev_guard.get_ref().unwrap(); rtl_dev_info.send_packet(&packet_data); // send another packet - } + drop(rtl_dev_guard); enable_network_interrupts(); } - timeout += 1; - if timeout == wait_timeout * 18 { + // Don't try forever + retries += 1; + if retries == wait_timeout * 18 { socket.close(); return false; } } + // Close the raw socket socket.close(); + // Get the driver object again disable_network_interrupts(); let mut rtl_dev_guard = NET_INFO.lock(); let rtl_dev_info = rtl_dev_guard.get_mut().unwrap(); if pkt_data.get_type() == LayerType::DHCP { + // With the driver object, record the information from the DHCP server let dhcp_res = pkt_data.unwrap_dhcp(); rtl_dev_info.my_ip_address = Some(dhcp_res.my_ip.val()); + rtl_dev_info.dhcp_server_ip = Some(dhcp_res.server_ip.val()); + // Debug print my IP let ip = dhcp_res.my_ip.swapped_endianness(); println!("[INFO] IP-Address Assigned As {}.{}.{}.{}", ip[0], ip[1], ip[2], ip[3]); - rtl_dev_info.dhcp_server_ip = Some(dhcp_res.server_ip.val()); } + // Release the driver object + drop(rtl_dev_guard); enable_network_interrupts(); true } +/// Spawn the task for processing packets from the packet queue +/// This *must* be spawned as soon as possible pub fn init_process_packet_data(exec: &mut Executor) { exec.spawn(Task::new(processing::init_packet_processing())); -} \ No newline at end of file +} diff --git a/kernel/src/network/ip.rs b/kernel/src/network/ip.rs index df63b3d..8b60680 100644 --- a/kernel/src/network/ip.rs +++ b/kernel/src/network/ip.rs @@ -75,13 +75,7 @@ impl IPPacket { } } - pub fn gen( - ethernet_packet: EthernetPacket, - data_length: u16, - protocol: Protocol, - src_ip: u32, - dst_ip: u32, - ) -> Self { + pub fn gen(ethernet_packet: EthernetPacket, data_length: u16, protocol: Protocol, src_ip: u32, dst_ip: u32) -> Self { let identification = unsafe { let mut id_gen = ID_GEN.lock(); let id_gen_old = id_gen.get(); diff --git a/kernel/src/network/layer.rs b/kernel/src/network/layer.rs index 95e2f7a..dcfaf2f 100644 --- a/kernel/src/network/layer.rs +++ b/kernel/src/network/layer.rs @@ -147,55 +147,49 @@ pub fn full_parse(packet: &[u8]) -> (usize, PacketData) { match next_type { LayerType::ETH => { let last_layer_data = last_layer.unwrap_undef(); - let (eth_layer, size, network_layer_type) = - EthernetPacket::parse(last_layer_data, &packet[i..]); + let (eth_layer, size, network_layer_type) = EthernetPacket::parse(last_layer_data, &packet[i..]); last_layer = PacketData::ETH(eth_layer); i += size; next_type = network_layer_type; - }, + } LayerType::IP => { let last_layer_data = last_layer.unwrap_eth(); - let (ip_layer, size, transport_layer_type) = - IPPacket::parse(last_layer_data, &packet[i..]); + let (ip_layer, size, transport_layer_type) = IPPacket::parse(last_layer_data, &packet[i..]); last_layer = PacketData::IP(ip_layer); i += size; next_type = transport_layer_type; - }, + } LayerType::ARP => { let last_layer_data = last_layer.unwrap_eth(); - let (arp_layer, size, transport_layer_type) = - ArpPacket::parse(last_layer_data, &packet[i..]); + let (arp_layer, size, transport_layer_type) = ArpPacket::parse(last_layer_data, &packet[i..]); last_layer = PacketData::ARP(arp_layer); i += size; next_type = transport_layer_type; - }, + } LayerType::UDP => { let last_layer_data = last_layer.unwrap_ip(); - let (udp_layer, size, application_layer_type) = - UDPPacket::parse(last_layer_data, &packet[i..]); + let (udp_layer, size, application_layer_type) = UDPPacket::parse(last_layer_data, &packet[i..]); last_layer = PacketData::UDP(udp_layer); i += size; next_type = application_layer_type; - }, + } LayerType::ICMP => { return (0, PacketData::UNDEF(EmptyLayer::new())); - }, + } LayerType::DHCP => { let last_layer_data = last_layer.unwrap_udp(); - let (dhcp_layer, size, empty_type) = - DHCPPacket::parse(last_layer_data, &packet[i..]); + let (dhcp_layer, size, empty_type) = DHCPPacket::parse(last_layer_data, &packet[i..]); last_layer = PacketData::DHCP(dhcp_layer); i += size; next_type = empty_type; - }, + } LayerType::TCP => { let last_layer_data = last_layer.unwrap_ip(); - let (tcp_layer, size, empty_type) = - TCPPacket::parse(last_layer_data, &packet[i..]); + let (tcp_layer, size, empty_type) = TCPPacket::parse(last_layer_data, &packet[i..]); last_layer = PacketData::TCP(tcp_layer); i += size; next_type = empty_type; - }, + } LayerType::ERR => { return (0, PacketData::ERR(EmptyLayer::new())); } diff --git a/kernel/src/network/mod.rs b/kernel/src/network/mod.rs index 2ab317e..6aec4d6 100644 --- a/kernel/src/network/mod.rs +++ b/kernel/src/network/mod.rs @@ -7,15 +7,15 @@ pub mod ethernet; pub mod init; pub mod ip; pub mod layer; -pub mod rtl8139; pub mod raw_socket; -pub mod udp; +pub mod rtl8139; pub mod socket; pub mod tcp; mod tcp_session; +pub mod udp; // todo: remove pub until things break... mod arp_table; pub mod constants; mod netsync; -mod raw_array; mod processing; +mod raw_array; diff --git a/kernel/src/network/netsync.rs b/kernel/src/network/netsync.rs index 0f88bfe..79614e6 100644 --- a/kernel/src/network/netsync.rs +++ b/kernel/src/network/netsync.rs @@ -1,10 +1,8 @@ use spin::MutexGuard; -use super::rtl8139::{RTL8139, NetworkConfig}; +use super::rtl8139::{NetworkConfig, RTL8139}; -struct InterruptGuard { - -} +struct InterruptGuard {} pub struct NetworkInterruptsGuard<'a> { data: MutexGuard<'a, Option>, @@ -38,12 +36,10 @@ impl SafeRTL8139 { pub fn new(data: spin::Mutex>, config: spin::Mutex) -> Self { Self { data, config } } - + pub fn lock(&self) -> NetworkInterruptsGuard { // disable_network_interrupts(); - return NetworkInterruptsGuard { - data: self.data.lock(), - }; + return NetworkInterruptsGuard { data: self.data.lock() }; } pub fn lock_no_disable(&self) -> MutexGuard> { return self.data.lock(); diff --git a/kernel/src/network/processing.rs b/kernel/src/network/processing.rs index 146dbec..ef6ab6a 100644 --- a/kernel/src/network/processing.rs +++ b/kernel/src/network/processing.rs @@ -7,14 +7,14 @@ use crate::{ network::{ arp::ArpPacket, arp_table::ArpEntry, - constants::{ARP_PORT, BROADCAST_ADDR, TCP_SYN, TCP_FIN, TCP_ACK}, + constants::{ARP_PORT, BROADCAST_ADDR, TCP_ACK, TCP_FIN, TCP_SYN}, ethernet::{EthType, EthernetPacket}, ip::{IPPacket, Protocol}, layer::{Layer, LayerType, PacketData}, raw_socket::{wake_sockets, NetworkErrors}, rtl8139::{disable_network_interrupts, enable_network_interrupts, NET_INFO}, tcp::TCPPacket, - tcp_session::{TCPSession, SessionAction}, + tcp_session::{SessionAction, TCPSession}, }, println, }; @@ -72,9 +72,7 @@ pub async fn init_packet_processing() { let mut raw_packets = PendingProcessingStream::new(); while let Some(pkt_data) = raw_packets.next().await { let amount_parsed_and_pkt = full_parse(pkt_data.as_slice()); - if amount_parsed_and_pkt.1.get_type() == LayerType::ERR - || amount_parsed_and_pkt.1.get_type() == LayerType::ICMP - { + if amount_parsed_and_pkt.1.get_type() == LayerType::ERR || amount_parsed_and_pkt.1.get_type() == LayerType::ICMP { // Don't deal with unrecognized packets (will fail assert) continue; } @@ -89,11 +87,7 @@ pub async fn init_packet_processing() { PacketData::ARP(arp) => { // todo: also check for broadcast if arp.recp_ip.val() == rtl_dev_info.my_ip_address.unwrap_or(0) { - let eth_layer = EthernetPacket::gen( - arp.sender_mac.val(), - rtl_dev_info.mac_address.unwrap(), - EthType::Arp, - ); + let eth_layer = EthernetPacket::gen(arp.sender_mac.val(), rtl_dev_info.mac_address.unwrap(), EthType::Arp); let arp_layer = ArpPacket::gen( eth_layer, rtl_dev_info.my_ip_address.unwrap_or(BROADCAST_ADDR), @@ -114,11 +108,7 @@ pub async fn init_packet_processing() { if !rtl_dev_info.ports.contains_key(&ARP_PORT) { rtl_dev_info.ports.insert(ARP_PORT, VecDeque::new()); } - rtl_dev_info - .ports - .get_mut(&ARP_PORT) - .unwrap() - .push_back(Ok(PacketData::ARP(arp))); + rtl_dev_info.ports.get_mut(&ARP_PORT).unwrap().push_back(Ok(PacketData::ARP(arp))); wake_sockets(ARP_PORT); } } @@ -132,11 +122,7 @@ pub async fn init_packet_processing() { if !rtl_dev_info.ports.contains_key(&dst_port) { rtl_dev_info.ports.insert(dst_port, VecDeque::new()); } - rtl_dev_info - .ports - .get_mut(&dst_port) - .unwrap() - .push_back(Ok(PacketData::DHCP(dhcp))); + rtl_dev_info.ports.get_mut(&dst_port).unwrap().push_back(Ok(PacketData::DHCP(dhcp))); wake_sockets(dst_port); } } @@ -147,11 +133,7 @@ pub async fn init_packet_processing() { if !rtl_dev_info.ports.contains_key(&dst_port) { rtl_dev_info.ports.insert(dst_port, VecDeque::new()); } - rtl_dev_info - .ports - .get_mut(&dst_port) - .unwrap() - .push_back(Ok(PacketData::UDP(udp))); + rtl_dev_info.ports.get_mut(&dst_port).unwrap().push_back(Ok(PacketData::UDP(udp))); wake_sockets(dst_port); } } @@ -164,11 +146,7 @@ pub async fn init_packet_processing() { if !rtl_dev_info.ports.contains_key(&dst_port) { rtl_dev_info.ports.insert(dst_port, VecDeque::new()); } - let session_key = TCPSession::gen_session_key( - tcp.ip.source_ip.val(), - tcp.src_port.val(), - tcp.dest_port.val(), - ); + let session_key = TCPSession::gen_session_key(tcp.ip.source_ip.val(), tcp.src_port.val(), tcp.dest_port.val()); // Open up a session if !rtl_dev_config.tcp_sessions.contains_key(&session_key) { if (tcp.get_flags() & TCP_SYN) == 0 { @@ -177,11 +155,7 @@ pub async fn init_packet_processing() { return; } // Compact the first packet we receive as the session creation - let eth_layer = EthernetPacket::gen( - tcp.ip.eth.src_mac.val(), - tcp.ip.eth.dest_mac.val(), - EthType::IPv4, - ); + let eth_layer = EthernetPacket::gen(tcp.ip.eth.src_mac.val(), tcp.ip.eth.dest_mac.val(), EthType::IPv4); let ip_layer = IPPacket::gen( eth_layer, 0, // leaving size undefined for the template @@ -189,14 +163,8 @@ pub async fn init_packet_processing() { tcp.ip.destination_ip.val(), tcp.ip.source_ip.val(), ); - let tcp_layer = - TCPPacket::gen(ip_layer, tcp.dest_port.val(), tcp.src_port.val()); - let session = TCPSession::new( - tcp_layer, - tcp.ip.source_ip.val(), - tcp.src_port.val(), - tcp.dest_port.val(), - ); + let tcp_layer = TCPPacket::gen(ip_layer, tcp.dest_port.val(), tcp.src_port.val()); + let session = TCPSession::new(tcp_layer, tcp.ip.source_ip.val(), tcp.src_port.val(), tcp.dest_port.val()); // Lets push to the port -- we are listening then we need to create a new session rtl_dev_info .ports @@ -204,9 +172,7 @@ pub async fn init_packet_processing() { .unwrap() .push_back(Ok(PacketData::TCP(tcp.clone()))); wake_sockets(dst_port); - rtl_dev_config - .tcp_sessions - .insert(session.session_key(), session); + rtl_dev_config.tcp_sessions.insert(session.session_key(), session); } // todo: SYN COOKIES? // todo: Should I have a buffer limit? -- @@ -239,11 +205,7 @@ pub async fn init_packet_processing() { } else { unreachable!(); }; - rtl_dev_info - .ports - .get_mut(&session_key) - .unwrap() - .push_back(res); + rtl_dev_info.ports.get_mut(&session_key).unwrap().push_back(res); wake_sockets(session_key); } } diff --git a/kernel/src/network/raw_array.rs b/kernel/src/network/raw_array.rs index be4f8f1..d6c93dc 100644 --- a/kernel/src/network/raw_array.rs +++ b/kernel/src/network/raw_array.rs @@ -51,11 +51,7 @@ pub struct WrappingRawArray { impl WrappingRawArray { /// An infinite array beginning at "start" and wrapping after size bytes pub fn new(start: *const u8, size: usize) -> Self { - WrappingRawArray { - start, - pos: 0, - size, - } + WrappingRawArray { start, pos: 0, size } } // Ignore values diff --git a/kernel/src/network/raw_socket.rs b/kernel/src/network/raw_socket.rs index 842f07d..ca0f651 100644 --- a/kernel/src/network/raw_socket.rs +++ b/kernel/src/network/raw_socket.rs @@ -1,7 +1,10 @@ -use futures_util::{Stream, task::AtomicWaker}; +use futures_util::{task::AtomicWaker, Stream}; use hashbrown::HashMap; -use crate::{task::timeout::{register_timeout, cancel_timeout, TimeoutID}, println}; +use crate::{ + println, + task::timeout::{cancel_timeout, register_timeout, TimeoutID}, +}; use super::{ layer::PacketData, @@ -29,7 +32,7 @@ pub enum NetworkErrors { pub struct RawSocket { port: u64, timeout_in_epochs: u16, - timeout_active: bool, + timeout_active: bool, timeout_id: TimeoutID, } @@ -48,7 +51,12 @@ impl RawSocket { // and allocate a waker NEW_PACKET_WAKER.lock().insert(port, AtomicWaker::new()); enable_network_interrupts(); - Ok(RawSocket { port, timeout_in_epochs, timeout_active: false, timeout_id: TimeoutID::new(), }) + Ok(RawSocket { + port, + timeout_in_epochs, + timeout_active: false, + timeout_id: TimeoutID::new(), + }) } fn try_get_packet_inner(&self) -> Option> { @@ -74,7 +82,7 @@ impl RawSocket { vec.unwrap().clear(); } // tcp session information is removed by the socket (not the raw socket) - // since the socket is typed (tcp or udp) and the raw_socket purpose is just + // since the socket is typed (tcp or udp) and the raw_socket purpose is just // for polling for packets enable_network_interrupts(); } @@ -110,4 +118,4 @@ impl Stream for RawSocket { pub(crate) fn wake_sockets(port: u64) { // wake the port up NEW_PACKET_WAKER.lock()[&port].wake(); -} \ No newline at end of file +} diff --git a/kernel/src/network/rtl8139.rs b/kernel/src/network/rtl8139.rs index 2ce871c..8ed669a 100644 --- a/kernel/src/network/rtl8139.rs +++ b/kernel/src/network/rtl8139.rs @@ -16,14 +16,10 @@ use x86_64::{ use crate::interrupts::IDT; -use crate::network::constants::{ - CAPR, CR, CR_BUFE, CR_RE, CR_TE, RX_BUFFER_SIZE, RX_READ_PTR_MASK, -}; +use crate::network::constants::{CAPR, CR, CR_BUFE, CR_RE, CR_TE, RX_BUFFER_SIZE, RX_READ_PTR_MASK}; use crate::network::raw_array::WrappingRawArray; -use super::constants::{ - INTERRUPT_MASK, ROK, RTL_DEV, RTL_VEND, TOK, TRANSMIT_CMD, TRANSMIT_REG, -}; +use super::constants::{INTERRUPT_MASK, ROK, RTL_DEV, RTL_VEND, TOK, TRANSMIT_CMD, TRANSMIT_REG}; use super::processing::add_pkt_data; use super::raw_socket::NetworkErrors; use super::tcp_session::TCPSession; @@ -36,10 +32,7 @@ use super::{ use crate::{ interrupts::{InterruptHandler, PICS}, memory::BootInfoFrameAllocator, - network::{ - devices, - netsync::SafeRTL8139, - }, + network::{devices, netsync::SafeRTL8139}, println, }; @@ -63,8 +56,7 @@ lazy_static! { }; } -static mut INTERRUPTS_ARE_ENABLED: spin::Mutex = - spin::Mutex::new(InterruptCounter { data: 0 }); +static mut INTERRUPTS_ARE_ENABLED: spin::Mutex = spin::Mutex::new(InterruptCounter { data: 0 }); // Disable network interrupts (is thread safe) pub fn disable_network_interrupts() { let mut data = unsafe { INTERRUPTS_ARE_ENABLED.lock() }; @@ -141,9 +133,8 @@ fn recv_packet(rtl_dev_info: &RTL8139) { if unsafe { cmd_port.read() } & CR_BUFE == 0x0 { // Receive a packet by reading the buffer // ? Reading the buffer is naturally unsafe? Is there a better way? - let virtual_buffer_recv: VirtAddr = VirtAddr::new( - rtl_dev_info.recv_buffer.unwrap().as_u64() + rtl_dev_info.physical_mem_offset.unwrap(), - ); + let virtual_buffer_recv: VirtAddr = + VirtAddr::new(rtl_dev_info.recv_buffer.unwrap().as_u64() + rtl_dev_info.physical_mem_offset.unwrap()); // todo: check for packet validity https://www.cs.usfca.edu/~cruse/cs326f04/RTL8139_ProgrammersGuide.pdf let vb_recv_start: *const u8 = virtual_buffer_recv.as_ptr(); let mut rx_buffer = WrappingRawArray::new(vb_recv_start, RX_BUFFER_SIZE as usize); @@ -210,16 +201,11 @@ impl NetworkConfig { tcp_sessions: HashMap::with_capacity(10), } } - } impl RTL8139 { // Initialize the card - pub fn init( - &mut self, - frame_allocator: &mut BootInfoFrameAllocator, - physical_mem_offset: u64, - ) -> bool { + pub fn init(&mut self, frame_allocator: &mut BootInfoFrameAllocator, physical_mem_offset: u64) -> bool { let mut frames: Vec = vec![]; // create recv and send buffer space for _ in 0..6 { @@ -240,14 +226,8 @@ impl RTL8139 { next_start = frames[i].start_address() + frames[i].size(); } // Initialize data members - println!( - "[INFO] Setting up receive buffer at {}", - frames[0].start_address().as_u64() - ); - println!( - "[INFO] Setting up send buffer at {}", - frames[3].start_address().as_u64() - ); + println!("[INFO] Setting up receive buffer at {}", frames[0].start_address().as_u64()); + println!("[INFO] Setting up send buffer at {}", frames[3].start_address().as_u64()); self.send_buffer = Some(frames[0].start_address()); self.recv_buffer = Some(frames[3].start_address()); self.physical_mem_offset = Some(physical_mem_offset); @@ -263,10 +243,7 @@ impl RTL8139 { fn new(configs: Vec) -> Option { let mut use_dev = None; for dev in configs.iter() { - if dev.vendor_id == RTL_VEND as u16 - && dev.device_id == RTL_DEV as u16 - && dev.class_code == PCIClassCodes::NetworkController - { + if dev.vendor_id == RTL_VEND as u16 && dev.device_id == RTL_DEV as u16 && dev.class_code == PCIClassCodes::NetworkController { use_dev = Some(dev.clone()); } } @@ -326,8 +303,7 @@ impl RTL8139 { } // Register the interrupt handler for the card - IDT.lock() - .register_irq(self.config.irq.unwrap() as usize, network_handle); + IDT.lock().register_irq(self.config.irq.unwrap() as usize, network_handle); // Get MAC address let mac_addr = self.config.io_base.unwrap(); // + 0 offset @@ -398,8 +374,7 @@ impl RTL8139 { return; } - let virtual_buffer: VirtAddr = - VirtAddr::new(self.send_buffer.unwrap().as_u64() + self.physical_mem_offset.unwrap()); + let virtual_buffer: VirtAddr = VirtAddr::new(self.send_buffer.unwrap().as_u64() + self.physical_mem_offset.unwrap()); let virtual_buffer_ptr: *mut u8 = virtual_buffer.as_mut_ptr(); // Copy data into buffer for (i, item) in packet_data.iter().enumerate() { diff --git a/kernel/src/network/socket.rs b/kernel/src/network/socket.rs index 4c9c93d..10557a0 100644 --- a/kernel/src/network/socket.rs +++ b/kernel/src/network/socket.rs @@ -6,7 +6,7 @@ use futures_util::StreamExt; use crate::{ network::{constants::TCP_ACK, layer::LayerType, tcp::TCPPacket}, - println, print, + print, println, }; use super::{ @@ -35,11 +35,7 @@ impl NetworkQuery { } } // send arp packet - let eth_layer = EthernetPacket::gen( - BROADCAST_MAC, - rtl_dev_info.mac_address.unwrap(), - EthType::Arp, - ); + let eth_layer = EthernetPacket::gen(BROADCAST_MAC, rtl_dev_info.mac_address.unwrap(), EthType::Arp); let arp_layer = ArpPacket::gen(eth_layer, rtl_dev_info.my_ip_address.unwrap(), ip, true); rtl_dev_info.send_packet(&arp_layer.serialize()); drop(rtl_dev_info_locked); @@ -106,11 +102,7 @@ pub struct Socket { impl Socket { // Can't send yet -- must listen - pub async fn open( - socket_type: SocketType, - src_port: u16, - wait_timeout: u16, - ) -> Result { + pub async fn open(socket_type: SocketType, src_port: u16, wait_timeout: u16) -> Result { let mut chosen_src_port = src_port; disable_network_interrupts(); let rtl_dev_info_locked = NET_INFO.lock(); @@ -183,10 +175,8 @@ impl Socket { let tcp_pkt = pkt.unwrap_tcp(); let dest_address = tcp_pkt.ip.source_ip.val(); let dest_port = tcp_pkt.src_port.val(); - let session_key = - TCPSession::gen_session_key(dest_address, dest_port, self.src_port); - let raw_socket = - RawSocket::new(session_key, max(self.wait_timeout * 18, 1)).unwrap(); + let session_key = TCPSession::gen_session_key(dest_address, dest_port, self.src_port); + let raw_socket = RawSocket::new(session_key, max(self.wait_timeout * 18, 1)).unwrap(); println!("[INFO] Spawned new TCP session"); return Some(Socket { socket_type: SocketType::TCP, @@ -273,7 +263,7 @@ impl Socket { let session = rtl_dev_config.tcp_sessions.get_mut(&session_key).unwrap(); let fin_ack_pkt = session.close(); drop(rtl_dev_config); - if let Ok(pkt) = fin_ack_pkt { + if let Ok(pkt) = fin_ack_pkt { // With 6 retries 'outer: for _ in 0..6 { if NET_INFO.config.lock().tcp_sessions.get_mut(&session_key).unwrap().is_closed() { @@ -375,10 +365,7 @@ impl Socket { // Set up to receive ack let mut tcp_session_guard = NET_INFO.config.lock(); - let tcp_session = tcp_session_guard - .tcp_sessions - .get_mut(&self.session_key) - .unwrap(); + let tcp_session = tcp_session_guard.tcp_sessions.get_mut(&self.session_key).unwrap(); let message_pkt = tcp_session.process_send(data); if let Err(err) = message_pkt { return Err(err); @@ -407,10 +394,7 @@ impl Socket { // let pkt_data = pkt.unwrap_tcp(); // Check the acknowledgement to make sure everything is acked let mut tcp_session_guard = NET_INFO.config.lock(); - let tcp_session = tcp_session_guard - .tcp_sessions - .get_mut(&self.session_key) - .unwrap(); + let tcp_session = tcp_session_guard.tcp_sessions.get_mut(&self.session_key).unwrap(); if tcp_session.everything_acked() { // if so break out of the timeout loop break; @@ -433,31 +417,18 @@ impl Socket { } let eth_layer = EthernetPacket::gen(self.src_mac, self.dest_mac, ethernet::EthType::IPv4); let udp_size = UDPPacket::packet_size() + data.len() as u16; - let ip_layer = IPPacket::gen( - eth_layer, - udp_size, - Protocol::UDP, - self.src_address, - self.dest_ip, - ); + let ip_layer = IPPacket::gen(eth_layer, udp_size, Protocol::UDP, self.src_address, self.dest_ip); let data_len = data.len(); - let mut udp_layer = - UDPPacket::gen(ip_layer, self.src_port, self.dest_port, data_len as u16); + let mut udp_layer = UDPPacket::gen(ip_layer, self.src_port, self.dest_port, data_len as u16); udp_layer.data = data.to_vec(); // todo: split the data and return the amount actually written! let data_2_send = udp_layer.serialize(); let start_udp = data_2_send.len() - (UDPPacket::packet_size() as usize + data_len); let start_ip = start_udp - (IPPacket::packet_size() as usize); - udp_layer - .ip - .calculate_checksum(&data_2_send[start_ip..start_udp]); + udp_layer.ip.calculate_checksum(&data_2_send[start_ip..start_udp]); udp_layer.calculate_checksum(&data_2_send[start_udp..]); let data_2_send_final = udp_layer.serialize(); disable_network_interrupts(); - NET_INFO - .lock() - .get_ref() - .unwrap() - .send_packet(&data_2_send_final); + NET_INFO.lock().get_ref().unwrap().send_packet(&data_2_send_final); enable_network_interrupts(); Ok(data_len as u16) } diff --git a/kernel/src/network/tcp.rs b/kernel/src/network/tcp.rs index 4a672a2..6b7f3f2 100644 --- a/kernel/src/network/tcp.rs +++ b/kernel/src/network/tcp.rs @@ -115,9 +115,7 @@ impl Layer for TCPPacket { i += 1; } assert!(i == packet.get_header_offset().into()); // Valid i is header_length - let data_size = packet.ip.total_length.val() - - packet.get_header_offset() as u16 - - IPPacket::packet_size(); + let data_size = packet.ip.total_length.val() - packet.get_header_offset() as u16 - IPPacket::packet_size(); for _ in 0..data_size { packet.data.push(bytevec[i]); i += 1; @@ -138,9 +136,7 @@ impl Layer for TCPPacket { res.extend(self.urgent.data); res.extend(&self.options); res.extend(&self.data); - assert!( - res.len() == (20 + self.ip.serialize().len() + self.options.len() + self.data.len()) - ); + assert!(res.len() == (20 + self.ip.serialize().len() + self.options.len() + self.data.len())); res } diff --git a/kernel/src/network/tcp_session.rs b/kernel/src/network/tcp_session.rs index 04b11a0..d5e76c8 100644 --- a/kernel/src/network/tcp_session.rs +++ b/kernel/src/network/tcp_session.rs @@ -80,7 +80,7 @@ impl TCPSession { } pub fn is_closed(&self) -> bool { - self.session_state ==TCPSessionState::Closed || self.has_recv_ack_to_fin_ack && self.has_sent_ack_to_fin_ack + self.session_state == TCPSessionState::Closed || self.has_recv_ack_to_fin_ack && self.has_sent_ack_to_fin_ack } /// Create a packet to close the tcp session with @@ -94,9 +94,7 @@ impl TCPSession { tcp_pkt.turn_on_flags(TCP_FIN | TCP_ACK); // add the data size (maybe make this automatic?) - tcp_pkt.ip.total_length = Bytefield16::new( - TCPPacket::packet_size() + TCPPacket::options_size() + IPPacket::packet_size(), - ); + tcp_pkt.ip.total_length = Bytefield16::new(TCPPacket::packet_size() + TCPPacket::options_size() + IPPacket::packet_size()); tcp_pkt.seq_num = Bytefield32::new(self.sent_data_amount); tcp_pkt.ack_num = Bytefield32::new(self.recv_data_amount); @@ -116,9 +114,7 @@ impl TCPSession { tcp_pkt.turn_on_flags(TCP_RST); // add the data size (maybe make this automatic?) - tcp_pkt.ip.total_length = Bytefield16::new( - TCPPacket::packet_size() + TCPPacket::options_size() + IPPacket::packet_size(), - ); + tcp_pkt.ip.total_length = Bytefield16::new(TCPPacket::packet_size() + TCPPacket::options_size() + IPPacket::packet_size()); tcp_pkt.seq_num = Bytefield32::new(self.sent_data_amount); tcp_pkt.ack_num = Bytefield32::new(self.recv_data_amount); @@ -144,12 +140,8 @@ impl TCPSession { tcp_pkt.data = data.to_vec(); // todo: split the data and return the amount actually written! // add the data size (maybe make this automatic?) - tcp_pkt.ip.total_length = Bytefield16::new( - TCPPacket::packet_size() - + TCPPacket::options_size() - + IPPacket::packet_size() - + data.len() as u16, - ); + tcp_pkt.ip.total_length = + Bytefield16::new(TCPPacket::packet_size() + TCPPacket::options_size() + IPPacket::packet_size() + data.len() as u16); tcp_pkt.seq_num = Bytefield32::new(self.sent_data_amount); self.sent_data_amount += data.len() as u32; tcp_pkt.ack_num = Bytefield32::new(self.recv_data_amount); @@ -174,9 +166,7 @@ impl TCPSession { self.window_size = request.sliding_window.val(); let mut response = self.session_template.clone(); // Setting length to be 24 [sizeof(TCP-header) + sizeof(options)] + 20 [sizeof(IP-header)] - response.ip.total_length = Bytefield16::new( - TCPPacket::packet_size() + TCPPacket::options_size() + IPPacket::packet_size(), - ); + response.ip.total_length = Bytefield16::new(TCPPacket::packet_size() + TCPPacket::options_size() + IPPacket::packet_size()); // todo! check seq-num and ack-num before doing these state changes if has_fin_flag && self.session_state == TCPSessionState::Established { self.session_state = TCPSessionState::Closing; @@ -230,8 +220,7 @@ impl TCPSession { } TCPSessionState::Established => { let mut has_info = SessionAction::Drop; - if request.ack_num.val() > self.sent_data_acked - && request.ack_num.val() <= self.sent_data_amount { + if request.ack_num.val() > self.sent_data_acked && request.ack_num.val() <= self.sent_data_amount { // We are updating our record of how much data was acked self.sent_data_acked = request.ack_num.val(); has_info = SessionAction::PushUpstream; // todo: why do we need to push up an empty packet? @@ -262,7 +251,7 @@ impl TCPSession { } return (None, SessionAction::PushUpstream); } - + self.recv_data_amount += 1; response.ack_num = Bytefield32::new(self.recv_data_amount); response.seq_num = Bytefield32::new(self.sent_data_amount); diff --git a/kernel/src/process.rs b/kernel/src/process.rs index 0742d62..93c2fa8 100644 --- a/kernel/src/process.rs +++ b/kernel/src/process.rs @@ -62,38 +62,30 @@ impl Process { self.table_index } - fn resume_inner( - mut this: &mut wasmi::Store, - inputs: &[wasmi::Value], - ) -> Result<(), wasmi::Error> { - let (instance, resumable_call) = - match core::mem::replace(&mut this.data_mut().thread_state, ThreadState::Temp) { - ThreadState::NotStarted(instance_pre) => { - // Call the WASM start function if present: - // https://webassembly.github.io/spec/core/syntax/modules.html#syntax-start - // TODO: Set a hard gas limit for the non-resumable start() function - this.add_fuel(START_FUEL)?; - let instance = instance_pre.start(&mut this)?; - let entrypoint = instance.get_typed_func::<(), ()>(&this, "")?; - this.add_fuel(PREEMPT_FUEL)?; - let resumable_call = entrypoint.call_resumable(&mut this, ())?; - (instance, resumable_call) - } - ThreadState::InProgress(instance, invocation) => { - (instance, invocation.resume(&mut this, inputs)?) - } - ThreadState::Temp => unreachable!(), - ts @ (ThreadState::Finished(..) | ThreadState::Error(..)) => { - this.data_mut().thread_state = ts; - return Ok(()); - } - }; + fn resume_inner(mut this: &mut wasmi::Store, inputs: &[wasmi::Value]) -> Result<(), wasmi::Error> { + let (instance, resumable_call) = match core::mem::replace(&mut this.data_mut().thread_state, ThreadState::Temp) { + ThreadState::NotStarted(instance_pre) => { + // Call the WASM start function if present: + // https://webassembly.github.io/spec/core/syntax/modules.html#syntax-start + // TODO: Set a hard gas limit for the non-resumable start() function + this.add_fuel(START_FUEL)?; + let instance = instance_pre.start(&mut this)?; + let entrypoint = instance.get_typed_func::<(), ()>(&this, "")?; + this.add_fuel(PREEMPT_FUEL)?; + let resumable_call = entrypoint.call_resumable(&mut this, ())?; + (instance, resumable_call) + } + ThreadState::InProgress(instance, invocation) => (instance, invocation.resume(&mut this, inputs)?), + ThreadState::Temp => unreachable!(), + ts @ (ThreadState::Finished(..) | ThreadState::Error(..)) => { + this.data_mut().thread_state = ts; + return Ok(()); + } + }; this.data_mut().thread_state = match resumable_call { wasmi::TypedResumableCall::Finished(()) => ThreadState::Finished(instance), - wasmi::TypedResumableCall::Resumable(invocation) => { - ThreadState::InProgress(instance, invocation) - } + wasmi::TypedResumableCall::Resumable(invocation) => ThreadState::InProgress(instance, invocation), }; Ok(()) @@ -161,17 +153,12 @@ pub enum SchedulerError<'a> { Recoverable(&'a wasmi::core::Trap), } -pub fn demo_scheduler_resume( - process: &mut wasmi::Store, -) -> Result<(), SchedulerError<'_>> { +pub fn demo_scheduler_resume(process: &mut wasmi::Store) -> Result<(), SchedulerError<'_>> { let initial_fuel_consumed = process.fuel_consumed().unwrap(); match &process.data().thread_state { ThreadState::NotStarted(..) => Process::resume(process, &[]), ThreadState::InProgress(.., invocation) - if matches!( - invocation.host_error().trap_code(), - Some(wasmi::core::TrapCode::OutOfFuel) - ) => + if matches!(invocation.host_error().trap_code(), Some(wasmi::core::TrapCode::OutOfFuel)) => { if initial_fuel_consumed < TOTAL_FUEL_HARD_LIMIT { process.add_fuel(PREEMPT_FUEL).unwrap(); @@ -185,11 +172,7 @@ pub fn demo_scheduler_resume( ThreadState::InProgress(_, invocation) => match invocation.host_error().trap_code() { // Return Ok(()) if we made forward progress but it seems we hit the fuel limit for // preemption. Otherwise (for any "recoverable" but unrecognized error), report failure. - Some(wasmi::core::TrapCode::OutOfFuel) - if initial_fuel_consumed < process.fuel_consumed().unwrap() => - { - Ok(()) - } + Some(wasmi::core::TrapCode::OutOfFuel) if initial_fuel_consumed < process.fuel_consumed().unwrap() => Ok(()), _ => return Err(SchedulerError::Recoverable(invocation.host_error())), }, ThreadState::Error(e) => Err(SchedulerError::Unrecoverable(e)), diff --git a/kernel/src/serial.rs b/kernel/src/serial.rs index 6fb62c8..5e9798c 100644 --- a/kernel/src/serial.rs +++ b/kernel/src/serial.rs @@ -16,10 +16,7 @@ pub fn _print(args: ::core::fmt::Arguments) { use x86_64::instructions::interrupts; interrupts::without_interrupts(|| { - SERIAL1 - .lock() - .write_fmt(args) - .expect("Printing to serial failed"); + SERIAL1.lock().write_fmt(args).expect("Printing to serial failed"); }); } diff --git a/kernel/src/task/executor.rs b/kernel/src/task/executor.rs index 2785ec7..e9f270c 100644 --- a/kernel/src/task/executor.rs +++ b/kernel/src/task/executor.rs @@ -97,10 +97,7 @@ impl TaskWaker { #[allow(clippy::new_ret_no_self)] fn new(task_id: TaskId, task_queue: Arc>) -> Waker { - Waker::from(Arc::new(TaskWaker { - task_id, - task_queue, - })) + Waker::from(Arc::new(TaskWaker { task_id, task_queue })) } } diff --git a/kernel/src/task/keyboard.rs b/kernel/src/task/keyboard.rs index 77a8e83..eb36524 100644 --- a/kernel/src/task/keyboard.rs +++ b/kernel/src/task/keyboard.rs @@ -63,11 +63,7 @@ pub(crate) fn add_scancode(scancode: u8) { pub async fn print_keypresses() { let mut scancodes = ScancodeStream::new(); - let mut keyboard = Keyboard::new( - ScancodeSet1::new(), - layouts::Us104Key, - HandleControl::Ignore, - ); + let mut keyboard = Keyboard::new(ScancodeSet1::new(), layouts::Us104Key, HandleControl::Ignore); while let Some(scancode) = scancodes.next().await { if let Ok(Some(key_event)) = keyboard.add_byte(scancode) { diff --git a/kernel/src/task/mod.rs b/kernel/src/task/mod.rs index db1eeaf..ea13200 100644 --- a/kernel/src/task/mod.rs +++ b/kernel/src/task/mod.rs @@ -6,9 +6,9 @@ use core::{future::Future, pin::Pin}; pub mod executor; pub mod keyboard; pub mod simple_executor; -pub mod udp_echo; -pub mod timeout; pub mod tcp_echo; +pub mod timeout; +pub mod udp_echo; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] struct TaskId(u64); diff --git a/kernel/src/task/tcp_echo.rs b/kernel/src/task/tcp_echo.rs index 611680d..a521817 100644 --- a/kernel/src/task/tcp_echo.rs +++ b/kernel/src/task/tcp_echo.rs @@ -1,5 +1,9 @@ use crate::{ - print, println, network::{socket::{SocketType, Socket}, raw_socket::NetworkErrors}, + network::{ + raw_socket::NetworkErrors, + socket::{Socket, SocketType}, + }, + print, println, }; use alloc::string::String; @@ -30,7 +34,7 @@ pub async fn tcp_echo_server() { println!("Closed socket"); break; } - }, + } Err(err) => println!("[USER-ERR] {:?}", err), } let res_or_err = socket.write(&mut data).await; @@ -49,7 +53,7 @@ pub async fn tcp_echo_server() { } } } - }, + } Err(err) => println!("[ERR] (Listening): {:?}", err), } -} \ No newline at end of file +} diff --git a/kernel/src/task/timeout.rs b/kernel/src/task/timeout.rs index fb32813..9606fc3 100644 --- a/kernel/src/task/timeout.rs +++ b/kernel/src/task/timeout.rs @@ -1,5 +1,5 @@ -use core::{sync::atomic::AtomicU64, cell::RefCell, task::Waker}; use alloc::collections::BinaryHeap; +use core::{cell::RefCell, sync::atomic::AtomicU64, task::Waker}; use lazy_static::lazy_static; use x86_64::instructions::interrupts; @@ -24,7 +24,12 @@ struct TimeoutEntry { impl TimeoutEntry { pub fn new(id: TimeoutID, epochs: u64, waker: Waker) -> Self { - TimeoutEntry { id, epochs, waker, cancelled: false } + TimeoutEntry { + id, + epochs, + waker, + cancelled: false, + } } } @@ -61,7 +66,11 @@ pub fn register_timeout(after_epochs: u16, waker: Waker) -> TimeoutID { let timeout_id = TimeoutID::new(); interrupts::without_interrupts(|| { let mut timeout_queue = TIMEOUT_QUEUE.lock(); - timeout_queue.push(RefCell::new(TimeoutEntry::new(timeout_id, unsafe { INTERRUPT_COUNTER } + after_epochs as u64, waker))); + timeout_queue.push(RefCell::new(TimeoutEntry::new( + timeout_id, + unsafe { INTERRUPT_COUNTER } + after_epochs as u64, + waker, + ))); }); timeout_id } @@ -85,11 +94,11 @@ pub fn poll_timeouts() { while let Some(timeout_entry) = timeout_queue.peek() { if timeout_entry.borrow().epochs <= unsafe { INTERRUPT_COUNTER } { if !timeout_entry.borrow().cancelled { - timeout_entry.borrow().waker.wake_by_ref(); + timeout_entry.borrow().waker.wake_by_ref(); } timeout_queue.pop(); } else { break; } } -} \ No newline at end of file +} diff --git a/kernel/src/task/udp_echo.rs b/kernel/src/task/udp_echo.rs index 07840fb..9f2e117 100644 --- a/kernel/src/task/udp_echo.rs +++ b/kernel/src/task/udp_echo.rs @@ -1,5 +1,9 @@ use crate::{ - print, println, network::{socket::{SocketType, Socket}, raw_socket::NetworkErrors}, + network::{ + raw_socket::NetworkErrors, + socket::{Socket, SocketType}, + }, + print, println, }; use alloc::string::String; @@ -22,7 +26,7 @@ pub async fn udp_echo_server() { socket.close(); return; } - }, + } Err(err) => println!("[USER-ERR] {:?}", err), } let res_or_err = socket.write(&mut data).await; @@ -39,7 +43,7 @@ pub async fn udp_echo_server() { break; } } - }, + } Err(err) => println!("[ERR] {:?}", err), } -} \ No newline at end of file +} diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000..93470b6 --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1 @@ +max_width = 140 \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index 206a7af..41a4356 100644 --- a/src/main.rs +++ b/src/main.rs @@ -18,19 +18,15 @@ fn main() { cmd.arg("-netdev") .arg("user,id=net0,hostfwd=udp::5555-:5554,hostfwd=tcp::6666-:6664"); // Making sure we have the rtl8139 as a hardware resource - cmd.arg("-device") - .arg("rtl8139,netdev=net0,mac=00:11:22:33:44:55"); + cmd.arg("-device").arg("rtl8139,netdev=net0,mac=00:11:22:33:44:55"); // Recording the network traffic in order to conduct debugging and analysis - cmd.arg("-object") - .arg("filter-dump,id=f1,netdev=net0,file=/tmp/dump.pcap"); + cmd.arg("-object").arg("filter-dump,id=f1,netdev=net0,file=/tmp/dump.pcap"); if uefi { cmd.arg("-bios").arg(ovmf_prebuilt::ovmf_pure_efi()); - cmd.arg("-drive") - .arg(format!("format=raw,file={uefi_path}")); + cmd.arg("-drive").arg(format!("format=raw,file={uefi_path}")); } else { - cmd.arg("-drive") - .arg(format!("format=raw,file={bios_path}")); + cmd.arg("-drive").arg(format!("format=raw,file={bios_path}")); } cmd.args(env::args_os().skip(1)); let mut child = cmd.spawn().unwrap(); From ab70c37f250da05bd42c8deabcec5228f44d75a2 Mon Sep 17 00:00:00 2001 From: EthanLavi <89608076+EthanLavi@users.noreply.github.com> Date: Sat, 18 Nov 2023 23:42:36 -0500 Subject: [PATCH 18/36] Comments --- Cargo.lock | 61 +++++++ kernel/src/interrupts.rs | 13 +- kernel/src/network/README.md | 2 +- kernel/src/network/TODO.md | 1 + kernel/src/network/arp.rs | 9 +- kernel/src/network/errors.rs | 18 ++ kernel/src/network/ip.rs | 83 ++++++--- kernel/src/network/layer.rs | 63 ++++++- kernel/src/network/mod.rs | 30 ++-- kernel/src/network/netsync.rs | 25 ++- kernel/src/network/network_query.rs | 84 +++++++++ kernel/src/network/processing.rs | 142 +++++++++------- kernel/src/network/raw_array.rs | 66 ++------ kernel/src/network/raw_socket.rs | 54 +++--- kernel/src/network/rtl8139.rs | 7 +- kernel/src/network/socket.rs | 253 +++++++++++++++++----------- kernel/src/network/tcp.rs | 91 ++++++---- kernel/src/network/tcp_session.rs | 118 +++++++++---- kernel/src/network/udp.rs | 59 +++++-- kernel/src/task/executor.rs | 10 +- kernel/src/task/tcp_echo.rs | 88 +++++----- kernel/src/task/timeout.rs | 34 +++- kernel/src/task/udp_echo.rs | 77 +++++---- 23 files changed, 949 insertions(+), 439 deletions(-) create mode 100644 kernel/src/network/errors.rs create mode 100644 kernel/src/network/network_query.rs diff --git a/Cargo.lock b/Cargo.lock index a764925..d1edc6d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,24 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "ahash" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91429305e9f0a25f6205c5b8e0d2db09e0708a7a6df0f42212bb56c32c8ac97a" +dependencies = [ + "cfg-if", + "once_cell", + "version_check", + "zerocopy", +] + +[[package]] +name = "allocator-api2" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5" + [[package]] name = "anyhow" version = "1.0.75" @@ -485,6 +503,16 @@ dependencies = [ "uuid", ] +[[package]] +name = "hashbrown" +version = "0.14.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f93e7192158dbcda357bdec5fb5788eebf8bbac027f3f33e719d29135ae84156" +dependencies = [ + "ahash", + "allocator-api2", +] + [[package]] name = "hermit-abi" version = "0.3.3" @@ -531,6 +559,7 @@ dependencies = [ "conquer-once", "crossbeam-queue", "futures-util", + "hashbrown", "lazy_static", "linked_list_allocator", "noto-sans-mono-bitmap", @@ -642,6 +671,12 @@ dependencies = [ "autocfg", ] +[[package]] +name = "once_cell" +version = "1.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" + [[package]] name = "os" version = "0.1.0" @@ -1042,6 +1077,12 @@ dependencies = [ "getrandom", ] +[[package]] +name = "version_check" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" + [[package]] name = "volatile" version = "0.2.7" @@ -1229,3 +1270,23 @@ dependencies = [ "rustversion", "volatile 0.4.6", ] + +[[package]] +name = "zerocopy" +version = "0.7.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd66a62464e3ffd4e37bd09950c2b9dd6c4f8767380fabba0d523f9a775bc85a" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "255c4596d41e6916ced49cfafea18727b24d67878fa180ddfd69b9df34fd1726" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] diff --git a/kernel/src/interrupts.rs b/kernel/src/interrupts.rs index 0dab7c5..7d1340f 100644 --- a/kernel/src/interrupts.rs +++ b/kernel/src/interrupts.rs @@ -27,12 +27,15 @@ impl InterruptIndex { } } +/// An interrupt manager for registering and blocking/unblocking interrupts pub struct InterruptHandler { idt: InterruptDescriptorTable, } pub type InterruptHandlerFunc = extern "x86-interrupt" fn(InterruptStackFrame) -> (); impl InterruptHandler { + /// Create an interrupt manager + /// Will create a IDT and maintain it with helper functions belonging to this object pub fn new() -> Self { let mut idt = InterruptDescriptorTable::new(); idt.breakpoint.set_handler_fn(breakpoint_handler); @@ -47,11 +50,12 @@ impl InterruptHandler { InterruptHandler { idt } } + /// Initialize the interrupt handler pub fn init(&self) { unsafe { self.idt.load_unsafe() }; } - // Static function for disabling an irq + /// Static function for unblocking an interrupt by irq_num pub fn unblock_irq(irq_num: u8) { let mut locked_pics = PICS.lock(); let data = unsafe { locked_pics.read_masks() }; @@ -63,7 +67,7 @@ impl InterruptHandler { } } - // Static function for re-enabling an IRQ + /// Static function for blocking an interrupt by irq_num pub fn block_irq(irq_num: u8) { let mut locked_pics = PICS.lock(); let data = unsafe { locked_pics.read_masks() }; @@ -75,6 +79,8 @@ impl InterruptHandler { } } + /// Register an interrupt handler + /// Used by the RTL card to receive device interrupts pub fn register_irq(&mut self, irq_num: usize, handler: InterruptHandlerFunc) { println!("Registered Handler @ {}", irq_num + 32); self.idt[irq_num + 32].set_handler_fn(handler); @@ -84,9 +90,11 @@ impl InterruptHandler { } lazy_static! { + /// An interrupt manager object protected by a mutex pub static ref IDT: spin::Mutex = spin::Mutex::new(InterruptHandler::new()); } +/// Initialize the IDT pub fn init_idt() { IDT.lock().init(); } @@ -110,6 +118,7 @@ extern "x86-interrupt" fn double_fault_handler(stack_frame: InterruptStackFrame, } extern "x86-interrupt" fn timer_interrupt_handler(_stack_frame: InterruptStackFrame) { + // Poll for timeouts in the timer interrupt handler poll_timeouts(); unsafe { PICS.lock().notify_end_of_interrupt(InterruptIndex::Timer.as_u8()); diff --git a/kernel/src/network/README.md b/kernel/src/network/README.md index cb7362a..b234e74 100644 --- a/kernel/src/network/README.md +++ b/kernel/src/network/README.md @@ -12,13 +12,13 @@ TODO [x] Async IO [x] Timeouts [] Refactor so that all of networking is tested -[] Refactor to include more documentation on the network module [] Refactor to verify checksums [] Verify other parts of the packet [] Fix synchronization to be much cleaner [] Clean up ugly stuff [] Refactor to be all constants [] search for todo and fix thoses +[] Refactor to include more documentation on the network module ## Receiving packets diff --git a/kernel/src/network/TODO.md b/kernel/src/network/TODO.md index 8fd86bd..57101fc 100644 --- a/kernel/src/network/TODO.md +++ b/kernel/src/network/TODO.md @@ -11,3 +11,4 @@ * Refactor to be all constants * Search for todo and fix thoses * Benchmarking +* Rename ip and mac address to a standard \ No newline at end of file diff --git a/kernel/src/network/arp.rs b/kernel/src/network/arp.rs index b298656..05efd9f 100644 --- a/kernel/src/network/arp.rs +++ b/kernel/src/network/arp.rs @@ -6,7 +6,7 @@ use super::{ use alloc::vec; use alloc::vec::Vec; -/// An arp packet, implements Layer (42 bytes) +/// An arp packet, implements Layer (28 bytes) #[derive(Debug)] pub struct ArpPacket { /// The parent packet @@ -85,9 +85,8 @@ impl Layer for ArpPacket { fn parse(eth_layer: EthernetPacket, bytevec: &[u8]) -> (Self, usize, LayerType) { let mut packet = ArpPacket::new(); // create an empty packet - // Read ethernet packet and 28 bytes + // Save ethernet packet and read 20 bytes let mut i = 0; - // Extract the eth_layer packet.eth = eth_layer; // Read byte by byte into the struct packet.hardware_type = Bytefield16::read_inc(&bytevec[i..], &mut i); @@ -100,7 +99,8 @@ impl Layer for ArpPacket { packet.recp_mac = Bytefield48::read_inc(&bytevec[i..], &mut i); packet.recp_ip = Bytefield32::read_inc(&bytevec[i..], &mut i); assert!(i == 28); // Arp packet should be 28 bytes - // Return the packet, the amount of data consumed, and the next layer type (end of parse) + + // Return the packet, the amount of data consumed, and the next layer type (end of parse) (packet, i, LayerType::END) } @@ -123,7 +123,6 @@ impl Layer for ArpPacket { /// The amount of data that belongs to the packet-type fn packet_size() -> u16 { - // 28 bytes 28 } } diff --git a/kernel/src/network/errors.rs b/kernel/src/network/errors.rs new file mode 100644 index 0000000..7627a5a --- /dev/null +++ b/kernel/src/network/errors.rs @@ -0,0 +1,18 @@ +/// Network related errors +#[derive(Debug, PartialEq, Eq)] +pub enum NetworkErrors { + /// The port is used by another socket + PortInUse, + /// No ports are open + NoAvailablePort, + /// The destination is not reachable (can't resolve with ARP) + NonexistentHost, + /// The socket is in a bad socket state and we cannot perform the operation for you + BadSocketState, + /// A placeholder Network Error for unimplemented features (for development) + FeatureNotAvailableYet, + /// If a timeout occured + Timeout, + /// This is a special network error for when our TCP stream has closed + ClosedSocket, +} \ No newline at end of file diff --git a/kernel/src/network/ip.rs b/kernel/src/network/ip.rs index 8b60680..d63b01c 100644 --- a/kernel/src/network/ip.rs +++ b/kernel/src/network/ip.rs @@ -7,17 +7,19 @@ use super::{ layer::{HasChecksum, Layer, LayerType}, }; +/// Protocol for IP +/// - ICMP (in development), TCP, UDP #[derive(Debug, Clone, Copy)] #[repr(u8)] pub enum Protocol { ICMP = 1, TCP = 6, UDP = 17, - RDP = 27, Unsupported = 255, } impl Protocol { + /// Parse a u8 into a Protocol enum pub fn from(data: u8) -> Self { match data { 1 => Self::ICMP, @@ -28,37 +30,55 @@ impl Protocol { } } +// todo: Replace for AtomicU16? +/// ID generator struct (a wrapped u16) struct WrappedU16 { data: u16, } impl WrappedU16 { + /// Get the u16 pub fn get(&self) -> u16 { self.data } + /// Set the u16 pub fn set(&mut self, data: u16) { self.data = data; } } +/// Atomic u16 for id generation static mut ID_GEN: spin::Mutex = spin::Mutex::new(WrappedU16 { data: 0 }); + +/// A IP packet, implements Layer and HasChecksum (20 bytes) #[derive(Debug, Clone)] pub struct IPPacket { + /// The parent packet pub eth: EthernetPacket, - version_hlen: u8, // 1 byte - type_of_service: u8, // 1 byte - pub total_length: Bytefield16, // 2 bytes (public for checksumming) - identification: Bytefield16, // 2 bytes - flags_fragment_offset: Bytefield16, // 2 bytes - ttl: u8, // 1 byte - pub protocol: Protocol, // 1 byte (public for checksumming) - pub checksum: Bytefield16, // 2 bytes - pub source_ip: Bytefield32, // 4 bytes (public for checksumming) - pub destination_ip: Bytefield32, // 4 bytes (public for checksumming) - // 20 bytes in total + /// IP version (hardcoded) + version_hlen: u8, + /// Can increase urgency. Unused + type_of_service: u8, + /// The entire size of IP packet + data size + pub total_length: Bytefield16, + /// A unique ID to help de-fragmentation of the packet + identification: Bytefield16, + /// Flags to prevent fragmentation (we are fine with it) + flags_fragment_offset: Bytefield16, + /// How many router hops before we drop the packet + ttl: u8, + /// The protocol used to identify the data + pub protocol: Protocol, + /// The checksum to verify contents of the packet + pub checksum: Bytefield16, + /// The sender's IP address + pub source_ip: Bytefield32, + /// The recepient's IP address + pub destination_ip: Bytefield32, } impl IPPacket { + /// Create an empty packet with all 0s pub fn new() -> Self { IPPacket { eth: EthernetPacket::new(), @@ -75,21 +95,29 @@ impl IPPacket { } } - pub fn gen(ethernet_packet: EthernetPacket, data_length: u16, protocol: Protocol, src_ip: u32, dst_ip: u32) -> Self { + /// Generate a IP packet with + /// - eth_layer: is the ethernet frame associated with the packet + /// - data_length: is the data's size + /// - protocol: the protocol of the packet (TCP/UDP) + /// - src_ip: the sender's IP address + /// - dst_ip: the destination's IP address + pub fn gen(eth_layer: EthernetPacket, data_length: u16, protocol: Protocol, src_ip: u32, dst_ip: u32) -> Self { + // Generate a unique ID for the packet let identification = unsafe { let mut id_gen = ID_GEN.lock(); let id_gen_old = id_gen.get(); id_gen.set((id_gen_old + 1) % 0xFFFF); Bytefield16::new(id_gen.get()) }; + // Construct the packet IPPacket { - eth: ethernet_packet, + eth: eth_layer, version_hlen: 0x45, type_of_service: 0x0, total_length: Bytefield16::new(data_length + 20), // adding data length and size of IP packet identification, flags_fragment_offset: Bytefield16::new(0), - ttl: 120, + ttl: 120, // 120 - our packet needs to make it there protocol, checksum: Bytefield16::new(0), source_ip: Bytefield32::new(src_ip), @@ -99,15 +127,20 @@ impl IPPacket { } impl Layer for IPPacket { + /// The input layer for parse type Input = EthernetPacket; - fn parse(ethernet_layer: EthernetPacket, bytevec: &[u8]) -> (Self, usize, LayerType) + + /// Parsing an IP packet requires: + /// - eth_layer: a parsed Ethernet packet + /// - bytevec: the data to parse, with trailing but it starts where the packet must begin + fn parse(eth_layer: EthernetPacket, bytevec: &[u8]) -> (Self, usize, LayerType) where Self: Sized, { let mut packet = IPPacket::new(); // create an empty packet - // Read 20 bytes + // Save ethernet packet and read 20 bytes let mut i = 0; - packet.eth = ethernet_layer; + packet.eth = eth_layer; packet.version_hlen = Bytefield8::read_inc(&bytevec[i..], &mut i).val(); packet.type_of_service = Bytefield8::read_inc(&bytevec[i..], &mut i).val(); packet.total_length = Bytefield16::read_inc(&bytevec[i..], &mut i); @@ -119,18 +152,23 @@ impl Layer for IPPacket { packet.checksum = Bytefield16::read_inc(&bytevec[i..], &mut i); packet.source_ip = Bytefield32::read_inc(&bytevec[i..], &mut i); packet.destination_ip = Bytefield32::read_inc(&bytevec[i..], &mut i); - assert!(i == 20); // 20 bytes + // assert 20 bytes + assert!(i == 20); + // Match the protocol to determine next layer let layer_type = match packet.protocol { - Protocol::ICMP => LayerType::ICMP, + Protocol::ICMP => LayerType::ICMP, // unsupported right now Protocol::TCP => LayerType::TCP, Protocol::UDP => LayerType::UDP, - Protocol::RDP => LayerType::ERR, Protocol::Unsupported => LayerType::ERR, }; + + // Return the packet, the amount of data consumed, and the next layer type (end of parse) (packet, i, layer_type) } + /// Serialize the packet into a vector of bytes, ready to send over the network fn serialize(&self) -> Vec { + // Create a vector and serialize it let mut res = vec![]; res.extend(self.eth.serialize()); res.push(self.version_hlen); @@ -147,12 +185,15 @@ impl Layer for IPPacket { res } + /// The amount of data that belongs to the packet-type fn packet_size() -> u16 { 20 } } impl HasChecksum for IPPacket { + /// Calculate a checksum on the data and the packet + /// - will self mutate fn calculate_checksum(&mut self, data: &[u8]) { // Starting vars let mut sum: u32 = 0; diff --git a/kernel/src/network/layer.rs b/kernel/src/network/layer.rs index dcfaf2f..f9995b7 100644 --- a/kernel/src/network/layer.rs +++ b/kernel/src/network/layer.rs @@ -1,4 +1,3 @@ -use alloc::vec; use alloc::vec::Vec; use super::arp::ArpPacket; @@ -8,7 +7,9 @@ use super::ip::IPPacket; use super::tcp::TCPPacket; use super::udp::UDPPacket; +/// A trait to define a packet pub trait Layer { + /// The upper layer to serve as input to the packet's parse function type Input; /// Parse a bytevec (vector slice) and return the packet and the amount of data parsed @@ -22,6 +23,8 @@ pub trait Layer { /// The size of the packet fn packet_size() -> u16; } + +/// Empty layer to serve as a placeholder layer for parsing Ethernet packet #[derive(Debug)] pub struct EmptyLayer {} impl EmptyLayer { @@ -32,19 +35,23 @@ impl EmptyLayer { impl Layer for EmptyLayer { type Input = EmptyLayer; + + /// Shouldn't be used fn parse(_upper: EmptyLayer, _bytevec: &[u8]) -> (Self, usize, LayerType) where Self: Sized, { - (Self {}, 0, LayerType::END) + panic!("Don't use this function"); } + /// Shouldn't be used fn serialize(&self) -> Vec { - vec![] + panic!("Don't use this function"); } + /// Shouldn't be used fn packet_size() -> u16 { - 0 + panic!("Don't use this function"); } } @@ -53,21 +60,32 @@ pub trait HasChecksum { fn calculate_checksum(&mut self, data: &[u8]); } +/// A layer type to indicate what the next data holds #[derive(Debug, PartialEq, Eq)] pub enum LayerType { + /// EthernetPacket ETH, + /// IPPacket IP, + /// ARPPacket ARP, + /// UDPPacket UDP, + /// ICMPPacket ICMP, + /// DHCPPacket DHCP, + /// TCPPacket TCP, + /// An error occured ERR, - END, // the default layer type + /// No more data (but not error) + END, } +// todo: reduce size of enum /// Wrapper type to allow me to return a generic -/// todo: reduce size of enum +/// Is both a type (what is the kind) and a packet (something that implements Layer) #[derive(Debug)] pub enum PacketData { ETH(EthernetPacket), @@ -82,48 +100,56 @@ pub enum PacketData { } impl PacketData { + /// Forcefully unwrap the packet as EthernetPacket pub fn unwrap_eth(self) -> EthernetPacket { match self { PacketData::ETH(val) => val, _ => unreachable!("Mismatched type. Couldn't unwrap"), } } + /// Forcefully unwrap the packet as IPPacket pub fn unwrap_ip(self) -> IPPacket { match self { PacketData::IP(val) => val, _ => unreachable!("Mismatched type. Couldn't unwrap"), } } + /// Forcefully unwrap the packet as ARPPacket pub fn unwrap_arp(self) -> ArpPacket { match self { PacketData::ARP(val) => val, _ => unreachable!("Mismatched type. Couldn't unwrap"), } } + /// Forcefully unwrap the packet as UDPPacket pub fn unwrap_udp(self) -> UDPPacket { match self { PacketData::UDP(val) => val, _ => unreachable!("Mismatched type. Couldn't unwrap"), } } + /// Forcefully unwrap the packet as DHCPPacket pub fn unwrap_dhcp(self) -> DHCPPacket { match self { PacketData::DHCP(val) => val, _ => unreachable!("Mismatched type. Couldn't unwrap"), } } + /// Forcefully unwrap something that is undefined as an empty layer pub fn unwrap_undef(self) -> EmptyLayer { match self { PacketData::UNDEF(val) => val, _ => unreachable!("Mismatched type. Couldn't unwrap"), } } + /// Forcefully unwrap the packet as TCPPacket pub fn unwrap_tcp(self) -> TCPPacket { match self { PacketData::TCP(val) => val, _ => unreachable!("Mismatched type. Couldn't unwrap"), } } + /// Get the type of the PacketData by decomposing into LayerType pub fn get_type(&self) -> LayerType { match self { PacketData::ETH(_) => LayerType::ETH, @@ -139,61 +165,86 @@ impl PacketData { } } +/// Put together the parse functions to fully parse a packet +/// - returns the amount of data parsed and the end packet pub fn full_parse(packet: &[u8]) -> (usize, PacketData) { + // Start the state machine with an undefined packet and a next_type of ethernet packet let mut i = 0; let mut last_layer = PacketData::UNDEF(EmptyLayer::new()); let mut next_type = LayerType::ETH; loop { + // Iterate matching on the state match next_type { LayerType::ETH => { + // ethernet state - unwrap the last layer as undefined let last_layer_data = last_layer.unwrap_undef(); + // parse the data (starting from i) let (eth_layer, size, network_layer_type) = EthernetPacket::parse(last_layer_data, &packet[i..]); + // save the ethernet packet, increment i, and set the next type to be what the Ethernet Packet dictates last_layer = PacketData::ETH(eth_layer); i += size; next_type = network_layer_type; } LayerType::IP => { + // IP state - unwrap the last layer as ethernet let last_layer_data = last_layer.unwrap_eth(); + // parse the data (starting from i) let (ip_layer, size, transport_layer_type) = IPPacket::parse(last_layer_data, &packet[i..]); + // save the IP packet, increment i, and set the next type to be what the IP Packet dictates last_layer = PacketData::IP(ip_layer); i += size; next_type = transport_layer_type; } LayerType::ARP => { + // ARP state - unwrap the last layer as ethernet let last_layer_data = last_layer.unwrap_eth(); + // parse the data (starting from i) let (arp_layer, size, transport_layer_type) = ArpPacket::parse(last_layer_data, &packet[i..]); + // save the ARP packet, increment i, and set the next type to be what the ARP Packet dictates last_layer = PacketData::ARP(arp_layer); i += size; next_type = transport_layer_type; } LayerType::UDP => { + // UDP state - unwrap the last layer as IP let last_layer_data = last_layer.unwrap_ip(); + // parse the data (starting from i) let (udp_layer, size, application_layer_type) = UDPPacket::parse(last_layer_data, &packet[i..]); + // save the UDP packet, increment i, and set the next type to be what the UDP Packet dictates last_layer = PacketData::UDP(udp_layer); i += size; next_type = application_layer_type; } LayerType::ICMP => { + // ICMP is unimplemented. Return an undefined packet with 0 data parsed return (0, PacketData::UNDEF(EmptyLayer::new())); } LayerType::DHCP => { + // DHCP state - unwrap the last layer as UDP let last_layer_data = last_layer.unwrap_udp(); + // parse the data (starting from i) let (dhcp_layer, size, empty_type) = DHCPPacket::parse(last_layer_data, &packet[i..]); + // save the DHCP packet, increment i, and set the next type to be what the DHCP Packet dictates last_layer = PacketData::DHCP(dhcp_layer); i += size; next_type = empty_type; } LayerType::TCP => { + // TCP state - unwrap the last layer as IP let last_layer_data = last_layer.unwrap_ip(); + // parse the data (starting from i) let (tcp_layer, size, empty_type) = TCPPacket::parse(last_layer_data, &packet[i..]); + // save the TCP packet, increment i, and set the next type to be what the TCP Packet dictates last_layer = PacketData::TCP(tcp_layer); i += size; next_type = empty_type; } LayerType::ERR => { + // Got an error so return an error packet with 0 data parsed return (0, PacketData::ERR(EmptyLayer::new())); } LayerType::END => { + // Reached a safe end state for the state machine. Return the last layer and how much data we parsed return (i, last_layer); } } diff --git a/kernel/src/network/mod.rs b/kernel/src/network/mod.rs index 6aec4d6..95d3e60 100644 --- a/kernel/src/network/mod.rs +++ b/kernel/src/network/mod.rs @@ -1,21 +1,25 @@ -pub mod arp; -pub mod bytefield; -pub mod command_register; -pub mod devices; -pub mod dhcp; -pub mod ethernet; +// Outward facing modules pub mod init; -pub mod ip; -pub mod layer; -pub mod raw_socket; pub mod rtl8139; pub mod socket; -pub mod tcp; +pub mod errors; + +// Internal workings +mod constants; +mod arp; +mod bytefield; +mod command_register; +mod devices; +mod dhcp; +mod ethernet; +mod ip; +mod layer; +mod raw_socket; +mod tcp; mod tcp_session; -pub mod udp; -// todo: remove pub until things break... +mod udp; mod arp_table; -pub mod constants; mod netsync; mod processing; mod raw_array; +mod network_query; \ No newline at end of file diff --git a/kernel/src/network/netsync.rs b/kernel/src/network/netsync.rs index 79614e6..46818de 100644 --- a/kernel/src/network/netsync.rs +++ b/kernel/src/network/netsync.rs @@ -4,15 +4,21 @@ use super::rtl8139::{NetworkConfig, RTL8139}; struct InterruptGuard {} +/// A guard for disabling interrupts, accessing interrupt-sensitive locks, and then re-enabling interrupts +/// - in progress - doesn't function correctly yet pub struct NetworkInterruptsGuard<'a> { + /// Internal mutex guard protected by the network interrupts guard data: MutexGuard<'a, Option>, } + impl NetworkInterruptsGuard<'_> { + /// Get the internals mutable pub fn get_mut(&mut self) -> Option<&mut RTL8139> { return self.data.as_mut(); } + /// Get a reference to the internals not-mutable pub fn get_ref(&self) -> Option<&RTL8139> { return self.data.as_ref(); } @@ -27,36 +33,49 @@ impl Drop for NetworkInterruptsGuard<'_> { } } +/// A driver that is "safe" to access without deadlocking with interrupt handler pub struct SafeRTL8139 { data: spin::Mutex>, pub config: spin::Mutex, } impl SafeRTL8139 { + /// Construct a new safe RTL to protect a real RTL pub fn new(data: spin::Mutex>, config: spin::Mutex) -> Self { Self { data, config } } + /// Get the internals without deadlocking with interrupt handler by disabling interrupts with the network interrupts guard pub fn lock(&self) -> NetworkInterruptsGuard { // disable_network_interrupts(); return NetworkInterruptsGuard { data: self.data.lock() }; } - pub fn lock_no_disable(&self) -> MutexGuard> { + + /// Get the internals without disabling interrupts --> this is unsafe + /// Should only be used in interrupt handler + pub unsafe fn lock_no_disable(&self) -> MutexGuard> { return self.data.lock(); } } +/// A counter for how many times interrupts were disabled pub struct InterruptCounter { - pub data: u32, + data: u32, } impl InterruptCounter { + /// Create a new interrupt counter object with initial value 0 + pub const fn new() -> Self { + InterruptCounter { data: 0 } + } + /// Get the value pub fn get(&self) -> u32 { self.data } + /// Increment the value pub fn inc(&mut self) { self.data += 1; } - + /// Decrement the value pub fn dec(&mut self) { self.data -= 1; } diff --git a/kernel/src/network/network_query.rs b/kernel/src/network/network_query.rs new file mode 100644 index 0000000..4a811f7 --- /dev/null +++ b/kernel/src/network/network_query.rs @@ -0,0 +1,84 @@ +use futures_util::StreamExt; + +use super::{ + arp::ArpPacket, + constants::{ARP_PORT, BROADCAST_MAC}, + ethernet::{EthType, EthernetPacket}, + layer::{Layer, LayerType}, + raw_socket::RawSocket, + rtl8139::{disable_network_interrupts, enable_network_interrupts, NET_INFO}, +}; + +// todo: move DHCP query here? + +/// A module for querying things with the network stack +pub struct NetworkQuery {} + +impl NetworkQuery { + /// Get a mac address from an ip address + /// wait_timeout is how many seconds until giving up + pub async fn get_mac_from_ip(wait_timeout: u32, ip: u32) -> Option { + // Acquire the driver + disable_network_interrupts(); + let mut rtl_dev_info_locked = NET_INFO.lock(); + let rtl_dev_info = rtl_dev_info_locked.get_mut().unwrap(); + + // iterate through entries in the arp table + for entry in rtl_dev_info.arp_table.iter() { + // todo: check for expired arps + // if entry matches, we can return from the cache + if entry.ip == ip { + return Some(entry.mac); + } + } + + // Otherwise, we send an arp packet and query the IP + let eth_layer = EthernetPacket::gen(BROADCAST_MAC, rtl_dev_info.mac_address.unwrap(), EthType::Arp); + let arp_layer = ArpPacket::gen(eth_layer, rtl_dev_info.my_ip_address.unwrap(), ip, true); + rtl_dev_info.send_packet(&arp_layer.serialize()); + + // and release the driver + drop(rtl_dev_info_locked); + enable_network_interrupts(); + + // wait for response by creating an ARP socket + let mut socket = RawSocket::new(ARP_PORT, 3).unwrap(); + let mut retries = 0; + loop { + // Loop on socket packets + if let Some(pkt) = socket.next().await { + // If we have a packet error, just loop again + if pkt.is_err() { + continue; + } + // Unwrap and type check the data + let pkt_data = pkt.unwrap(); + if pkt_data.get_type() != LayerType::ARP { + continue; + } + let arp_pkt = pkt_data.unwrap_arp(); + // todo: add to the ARP table + // Once we've unwrapped the packet, we can close the socket and return the sender mac + socket.close(); + return Some(arp_pkt.sender_mac.val()); + } else { + // If we timed-out + disable_network_interrupts(); + // acquire the driver again + let rtl_dev_guard = NET_INFO.lock(); + let rtl_dev_info = rtl_dev_guard.get_ref().unwrap(); + // send the packet + rtl_dev_info.send_packet(&arp_layer.serialize()); + // release the driver + drop(rtl_dev_guard); + enable_network_interrupts(); + } + // Count retries and if we exceed the limit, we die + retries += 1; + if retries == wait_timeout * 6 { + socket.close(); + return None; + } + } + } +} diff --git a/kernel/src/network/processing.rs b/kernel/src/network/processing.rs index ef6ab6a..2800863 100644 --- a/kernel/src/network/processing.rs +++ b/kernel/src/network/processing.rs @@ -7,14 +7,14 @@ use crate::{ network::{ arp::ArpPacket, arp_table::ArpEntry, - constants::{ARP_PORT, BROADCAST_ADDR, TCP_ACK, TCP_FIN, TCP_SYN}, + constants::{ARP_PORT, BROADCAST_ADDR, TCP_SYN}, ethernet::{EthType, EthernetPacket}, ip::{IPPacket, Protocol}, layer::{Layer, LayerType, PacketData}, - raw_socket::{wake_sockets, NetworkErrors}, + raw_socket::wake_sockets, rtl8139::{disable_network_interrupts, enable_network_interrupts, NET_INFO}, tcp::TCPPacket, - tcp_session::{SessionAction, TCPSession}, + tcp_session::{SessionAction, TCPSession}, errors::NetworkErrors, }, println, }; @@ -26,15 +26,22 @@ use core::{ use super::layer::full_parse; +/// A waker for waking the pending packet stream static PROCESS_VEC_WAKER: AtomicWaker = AtomicWaker::new(); +/// An array queue for data to parse static PENDING_DATA: OnceCell>> = OnceCell::uninit(); -pub struct PendingProcessingStream { +/// A empty struct with behavior +/// - the idea is its a queue for the interrupt handler to enqueue vectors to represent packets +/// - we have a task that polls this stream and processes them +struct PendingProcessingStream { _private: (), } impl PendingProcessingStream { - pub fn new() -> Self { + /// Create a new pending process stream + fn new() -> Self { + // Initialize the pending data array queue with max size 100 PENDING_DATA .try_init_once(|| ArrayQueue::new(100)) .expect("PendingProcessingStream::new should only be called once"); @@ -43,24 +50,34 @@ impl PendingProcessingStream { } impl Stream for PendingProcessingStream { + // Output tokens for the polling type Item = Vec; + /// Get the next vector of data fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll>> { + // Get the queue let queue = PENDING_DATA.try_get().expect("not initialized"); + // Get the next entry and register to be woken up let data = queue.pop(); PROCESS_VEC_WAKER.register(cx.waker()); match data { + // If we got some data -> return it Some(pkt_data) => Poll::Ready(Some(pkt_data)), + // Otherwise sleep because queue is empty None => Poll::Pending, } } } +/// A internal function of the module to append to the queue of potential packets pub(crate) fn add_pkt_data(data: Vec) { + // Try to get the array queue if let Ok(queue) = PENDING_DATA.try_get() { + // And push data if queue.push(data).is_err() { println!("[WARN] packet queue full; dropping packet"); } else { + // If we pushed data, wake the processing thing PROCESS_VEC_WAKER.wake(); } } else { @@ -68,16 +85,22 @@ pub(crate) fn add_pkt_data(data: Vec) { } } +/// Start the processing of packets +/// - this function never terminates until we receive None from the raw packet stream (which won't ever happen) pub async fn init_packet_processing() { - let mut raw_packets = PendingProcessingStream::new(); + // Create the stream + let mut raw_packets = PendingProcessingStream::new(); + // Get the next piece of data the interrupt handler drew from the card while let Some(pkt_data) = raw_packets.next().await { + // Parse it let amount_parsed_and_pkt = full_parse(pkt_data.as_slice()); if amount_parsed_and_pkt.1.get_type() == LayerType::ERR || amount_parsed_and_pkt.1.get_type() == LayerType::ICMP { // Don't deal with unrecognized packets (will fail assert) continue; } + // Assert we had a proper amount of data processed without erroring out assert!(amount_parsed_and_pkt.0 == pkt_data.len() || pkt_data.len() < 64); - // Try to get the device info + // Get the driver configuration disable_network_interrupts(); let mut net_dev = NET_INFO.lock(); // Get the device fields @@ -85,133 +108,136 @@ pub async fn init_packet_processing() { let mut rtl_dev_config = NET_INFO.config.lock(); match amount_parsed_and_pkt.1 { PacketData::ARP(arp) => { - // todo: also check for broadcast + // todo: also check for broadcast, and expire from arp table if arp.recp_ip.val() == rtl_dev_info.my_ip_address.unwrap_or(0) { + // WE GOT A REQUEST, so create a response let eth_layer = EthernetPacket::gen(arp.sender_mac.val(), rtl_dev_info.mac_address.unwrap(), EthType::Arp); - let arp_layer = ArpPacket::gen( - eth_layer, - rtl_dev_info.my_ip_address.unwrap_or(BROADCAST_ADDR), - arp.sender_ip.val(), - false, - ); + let ip_address = rtl_dev_info.my_ip_address.unwrap_or(BROADCAST_ADDR); + let arp_layer = ArpPacket::gen(eth_layer, ip_address, arp.sender_ip.val(), false); let arp_pkt = arp_layer.serialize(); + // and send it rtl_dev_info.send_packet(&arp_pkt); } else { - // todo: expire from arp table? + // WE GOT A RESPONSE, so save it in the arp table rtl_dev_info.arp_table.push(ArpEntry { mac: arp.sender_mac.val(), ip: arp.sender_ip.val(), expires: 0, }); + // If there was some process listening on the ARP "port" -> then we have to upstream the packet if rtl_dev_info.open_ports.contains(&ARP_PORT) { - // if we are listening on the port, try to insert it into the map + // if we are listening on the port, try to insert initialize it into the map if !rtl_dev_info.ports.contains_key(&ARP_PORT) { rtl_dev_info.ports.insert(ARP_PORT, VecDeque::new()); } + // Push the packet into the port structure and wake the port rtl_dev_info.ports.get_mut(&ARP_PORT).unwrap().push_back(Ok(PacketData::ARP(arp))); wake_sockets(ARP_PORT); } } } PacketData::DHCP(dhcp) => { + // DHCP packet let dst_port = dhcp.udp.dest_port.val() as u64; println!("[HANDLER] Found DHCP packet"); + // If we are listening on the DHCP port if rtl_dev_info.open_ports.contains(&dst_port) { println!("[HANDLER] Port {} is open", dst_port); - // if we are listening on the port, try to insert it into the map + // Try to initialize the port data structure if !rtl_dev_info.ports.contains_key(&dst_port) { rtl_dev_info.ports.insert(dst_port, VecDeque::new()); } + // Push back the dhcp packet and wake the port rtl_dev_info.ports.get_mut(&dst_port).unwrap().push_back(Ok(PacketData::DHCP(dhcp))); wake_sockets(dst_port); } } PacketData::UDP(udp) => { + // UDP packet let dst_port = udp.dest_port.val() as u64; + // If we are listening on the port if rtl_dev_info.open_ports.contains(&dst_port) { - // if we are listening on the port, try to insert it into the map + // Try to initialize the port data structure if !rtl_dev_info.ports.contains_key(&dst_port) { rtl_dev_info.ports.insert(dst_port, VecDeque::new()); } + // Push back the UDP packet and wake the port rtl_dev_info.ports.get_mut(&dst_port).unwrap().push_back(Ok(PacketData::UDP(udp))); wake_sockets(dst_port); } } PacketData::TCP(tcp) => { + // TCP Packet let dst_port = tcp.dest_port.val() as u64; if !rtl_dev_info.open_ports.contains(&dst_port) { + // If we aren't listening on the port, throw the packet out continue; } - // if we are listening on the port, try to insert it into the map + // Try to initialize the port structure if !rtl_dev_info.ports.contains_key(&dst_port) { rtl_dev_info.ports.insert(dst_port, VecDeque::new()); } + // Create the session key let session_key = TCPSession::gen_session_key(tcp.ip.source_ip.val(), tcp.src_port.val(), tcp.dest_port.val()); - // Open up a session + // Open up a session if it doesn't exist if !rtl_dev_config.tcp_sessions.contains_key(&session_key) { if (tcp.get_flags() & TCP_SYN) == 0 { - // Ignore requests when there is no request for syncing + // Don't create new session when there is no request for syncing enable_network_interrupts(); return; } - // Compact the first packet we receive as the session creation + // Compact the first packet we receive as a template for future requests let eth_layer = EthernetPacket::gen(tcp.ip.eth.src_mac.val(), tcp.ip.eth.dest_mac.val(), EthType::IPv4); - let ip_layer = IPPacket::gen( - eth_layer, - 0, // leaving size undefined for the template - Protocol::TCP, - tcp.ip.destination_ip.val(), - tcp.ip.source_ip.val(), - ); + // ! IMP: leaving size undefined for the template + let ip_layer = IPPacket::gen(eth_layer, 0, Protocol::TCP, tcp.ip.destination_ip.val(), tcp.ip.source_ip.val()); let tcp_layer = TCPPacket::gen(ip_layer, tcp.dest_port.val(), tcp.src_port.val()); + // Create a session with the template let session = TCPSession::new(tcp_layer, tcp.ip.source_ip.val(), tcp.src_port.val(), tcp.dest_port.val()); - // Lets push to the port -- we are listening then we need to create a new session + // Lets upstream our recevied packet to the listening port (and wake it) + // - This is used to split the socket into an "Established" session socket and a "Listening" socket rtl_dev_info .ports .get_mut(&dst_port) .unwrap() .push_back(Ok(PacketData::TCP(tcp.clone()))); wake_sockets(dst_port); + // Finally insert our tcp session rtl_dev_config.tcp_sessions.insert(session.session_key(), session); } - // todo: SYN COOKIES? - // todo: Should I have a buffer limit? -- - // ? I think maybe no, because upstream data should be prioritized -- - // ? Otherwise I might not have enough processing speed + // Get the tcp session let tcp_session = rtl_dev_config.tcp_sessions.get_mut(&session_key).unwrap(); - // Generate an acknowledgement and determine if the tcp packet has data + // Generate an acknowledgement via the session's process_recv function let ack_pkt = tcp_session.process_recv(&tcp); if let Some(response) = ack_pkt.0 { - // If we got a response packet to send back - // todo: what happens if our response is dropped... we need to re-ack? - // generally if no ack is receeived, the host will send another transmission - // ALSO we don't know if our ack was received or not, so we just wait for another transmission - // we could also get duplicate data tho? so we need to identify this case + // If we got a response packet to send back, send it + // If no ack is receeived, the host will send another transmission for us to respond to rtl_dev_info.send_packet(&response.serialize()); } - if ack_pkt.1 != SessionAction::Drop { - // Push the packet to the raw socket --> it will handle its data, if present - if rtl_dev_info.open_ports.contains(&session_key) { - // if we are listening on the session, try to insert it into the map - if !rtl_dev_info.ports.contains_key(&session_key) { - rtl_dev_info.ports.insert(session_key, VecDeque::new()); - } - // Insert the data and wake the socket - let res = if ack_pkt.1 == SessionAction::PushUpstream { - Ok(PacketData::TCP(tcp)) - } else if ack_pkt.1 == SessionAction::EndOfStream { - Err(NetworkErrors::ClosedSocket) - } else { - unreachable!(); - }; - rtl_dev_info.ports.get_mut(&session_key).unwrap().push_back(res); - wake_sockets(session_key); + // Interpret the action from the process_recv function + if ack_pkt.1 != SessionAction::Drop && rtl_dev_info.open_ports.contains(&session_key) { + // if we are listening on the session, try to init the packet queue on that end + if !rtl_dev_info.ports.contains_key(&session_key) { + rtl_dev_info.ports.insert(session_key, VecDeque::new()); } + let res = if ack_pkt.1 == SessionAction::PushUpstream { + // If the action is upstream, then we push the packet to the raw socket to handle it's data (if present) + Ok(PacketData::TCP(tcp)) + } else if ack_pkt.1 == SessionAction::EndOfStream { + // If the action is end of stream, we push the sentinel to indicate the stream has finished + Err(NetworkErrors::ClosedSocket) + } else { + // We shouldn't reach this unless we added another action and didn't handle it + unreachable!(); + }; + // Push back the packet or end-of-stream token to the session socket -> and then wake the socket + rtl_dev_info.ports.get_mut(&session_key).unwrap().push_back(res); + wake_sockets(session_key); } } - _ => {} // ignore others + _ => {} // ignore other packets } + // Release the guard and enable interrupts drop(net_dev); enable_network_interrupts(); } diff --git a/kernel/src/network/raw_array.rs b/kernel/src/network/raw_array.rs index d6c93dc..353183a 100644 --- a/kernel/src/network/raw_array.rs +++ b/kernel/src/network/raw_array.rs @@ -1,67 +1,33 @@ use alloc::vec; use alloc::vec::Vec; -use core::ops::Index; - -// Leaving this here unless we change an implementation that necessiates a differnt type of array -/*pub struct RawArray { - start: *const u8 -} - -impl RawArray { - /// An infinite array beginning at "start" - pub fn new(start: *const u8) -> Self { - RawArray { - start - } - } - - // Ignore values - pub fn shift_amount(&mut self, amount: usize) -> () { - self.start = unsafe { self.start.add(amount) }; - } - - // Move the array forward, "consuming" those values - pub fn trim(&mut self, amount: usize) -> Vec { - let mut res = vec![]; - for _ in 0..amount { - unsafe { - res.push(*self.start); - self.start = self.start.add(1); - } - } - res - } - -} - -impl Index for RawArray { - type Output = u8; - /// Index into the infinite array using raw pointers - fn index<'a>(&'a self, i: usize) -> &u8 { - unsafe { &(*self.start.add(i)) } - } -}*/ +/// An array to represent the buffer of the RTL8139 +/// (will wrap it's array accessess as required by the data-sheet) pub struct WrappingRawArray { + /// The starting address of the buffer start: *const u8, + /// The position of the reading pos: usize, + /// The max size of the buffer (helps to prevent going past the end raw_address) size: usize, } impl WrappingRawArray { - /// An infinite array beginning at "start" and wrapping after size bytes + /// An "infinite" array beginning at "start" and wrapping after size bytes pub fn new(start: *const u8, size: usize) -> Self { WrappingRawArray { start, pos: 0, size } } - // Ignore values + /// Ignore values by amount pub fn shift_amount(&mut self, amount: usize) { self.pos = (self.pos + amount) % self.size; } - // Move the array forward, "consuming" those values + /// Move the array forward, "consuming" those values pub fn trim(&mut self, amount: usize) -> Vec { + // Create a result vector let mut res = vec![]; + // Move to the starting position let mut tmp_start = unsafe { self.start.add(self.pos) }; for _ in 0..amount { // append the byte and move tmp_start forward @@ -71,21 +37,15 @@ impl WrappingRawArray { } // also increment the position self.pos += 1; - // if the position is equal to size (we are outside the buffer) + // if the position is equal to size (we are at the edge of the buffer) if self.pos == self.size { - // so we reset to the beginning + // so we reset to the beginning of the buffer self.pos = 0; tmp_start = self.start; } } + // return the result res } } -impl Index for WrappingRawArray { - type Output = u8; - /// Index into the infinite array using raw pointers - fn index(&self, i: usize) -> &u8 { - unsafe { &(*self.start.add(i % self.size)) } - } -} diff --git a/kernel/src/network/raw_socket.rs b/kernel/src/network/raw_socket.rs index ca0f651..d366024 100644 --- a/kernel/src/network/raw_socket.rs +++ b/kernel/src/network/raw_socket.rs @@ -1,46 +1,42 @@ use futures_util::{task::AtomicWaker, Stream}; use hashbrown::HashMap; -use crate::{ - println, - task::timeout::{cancel_timeout, register_timeout, TimeoutID}, -}; +use crate::task::timeout::{cancel_timeout, register_timeout, TimeoutID}; use super::{ layer::PacketData, - rtl8139::{disable_network_interrupts, enable_network_interrupts, NET_INFO}, + rtl8139::{disable_network_interrupts, enable_network_interrupts, NET_INFO}, errors::NetworkErrors, }; use core::task::{Context, Poll}; use lazy_static::lazy_static; lazy_static! { + /// A data structure to contain the wakers to awaken a sleeping network port pub static ref NEW_PACKET_WAKER: spin::Mutex> = spin::Mutex::new(HashMap::new()); } -#[derive(Debug, PartialEq, Eq)] -pub enum NetworkErrors { - PortInUse, - NoAvailablePort, - NonexistentHost, - BadSocketState, - FeatureNotAvailableYet, - Timeout, - /// This is a special network error for when our TCP stream has closed - ClosedSocket, -} - +/// A raw socket object to poll for packets on the network stack pub struct RawSocket { + /// The port that raw socket owns port: u64, + /// How many epochs (of the timer interrupt) until we spontaneously reawaken our processing timeout_in_epochs: u16, + /// If the timeout is registered on this raw socket timeout_active: bool, + /// The id of the timeout we registered (for cancellation) timeout_id: TimeoutID, } impl RawSocket { + /// Construct a new instance of the raw socket with port + /// - will register a timeout after timeout_in_epochs + /// - will return itself or a network error if cannot bind to the port pub fn new(port: u64, timeout_in_epochs: u16) -> Result { + // Acquire the driver guard disable_network_interrupts(); let mut rtl_dev_info_guard = NET_INFO.lock(); let rtl_dev_info = rtl_dev_info_guard.get_mut().unwrap(); + // Check if the port is in use if rtl_dev_info.open_ports.contains(&port) { enable_network_interrupts(); @@ -50,7 +46,12 @@ impl RawSocket { rtl_dev_info.open_ports.insert(port); // and allocate a waker NEW_PACKET_WAKER.lock().insert(port, AtomicWaker::new()); + + // Release the driver guard + drop(rtl_dev_info_guard); enable_network_interrupts(); + + // Return the raw socket's initial state Ok(RawSocket { port, timeout_in_epochs, @@ -59,7 +60,9 @@ impl RawSocket { }) } + /// Internal function to query for a packet fn try_get_packet_inner(&self) -> Option> { + // Acquire the driver and try pop from the queue let mut rtl_dev_info_guard = NET_INFO.lock(); let rtl_dev_info = rtl_dev_info_guard.get_mut().unwrap(); match rtl_dev_info.ports.get_mut(&self.port) { @@ -68,13 +71,15 @@ impl RawSocket { } } + /// Close the raw socket and release the resources associated with it pub fn close(self) { + // The raw socket closes by acquiring the driver disable_network_interrupts(); let mut rtl_dev_info_guard = NET_INFO.lock(); let rtl_dev_info = rtl_dev_info_guard.get_mut().unwrap(); - // close the port so that we don't receive anymore packets + // and closing the port so that we don't receive anymore packets rtl_dev_info.open_ports.remove(&self.port); - // Remove the waker, since the port has no listeners + // Remove the waker, since the port shouldn't have anymore listeners NEW_PACKET_WAKER.lock().remove(&self.port); // Try to clear all the pending packets from the port if rtl_dev_info.ports.contains_key(&self.port) { @@ -89,20 +94,29 @@ impl RawSocket { } impl Stream for RawSocket { + /// The result of the polling type Item = Result; + /// Poll for a new packet on the raw socket (it's main purpose) fn poll_next(mut self: core::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // Try to get a packet disable_network_interrupts(); let pkt = self.try_get_packet_inner(); enable_network_interrupts(); - // listen on the port with the waker + + // Then register the port with a waker let locked_waker_map = NEW_PACKET_WAKER.lock(); locked_waker_map[&self.port].register(cx.waker()); + if pkt.is_some() { + // If we found a packet - cancel any pending timeouts cancel_timeout(self.timeout_id); self.timeout_active = false; + // And return the packet Poll::Ready(pkt) } else if self.timeout_active { + // If we couldn't get a packet, and the timeout was active + // then we must have timed-out so we reset the timeout_active and return None from polling self.timeout_active = false; Poll::Ready(None) } else { diff --git a/kernel/src/network/rtl8139.rs b/kernel/src/network/rtl8139.rs index 8ed669a..37b1230 100644 --- a/kernel/src/network/rtl8139.rs +++ b/kernel/src/network/rtl8139.rs @@ -20,8 +20,8 @@ use crate::network::constants::{CAPR, CR, CR_BUFE, CR_RE, CR_TE, RX_BUFFER_SIZE, use crate::network::raw_array::WrappingRawArray; use super::constants::{INTERRUPT_MASK, ROK, RTL_DEV, RTL_VEND, TOK, TRANSMIT_CMD, TRANSMIT_REG}; +use super::errors::NetworkErrors; use super::processing::add_pkt_data; -use super::raw_socket::NetworkErrors; use super::tcp_session::TCPSession; use super::{ arp_table::ArpEntry, @@ -39,6 +39,7 @@ use crate::{ // ISR_ROK|ISR_TOK|ISR_RXOVW|ISR_TER|ISR_RER // ! URGENT: start checking other statuses for errors // todo: why can I get "send" interrupts... what is the purpose +// todo: write comments for this file // FROM OS DEV // N.B. If you find your driver suddenly freezes and stops receiving interrupts and you're using kvm/qemu. Try the option -no-kvm-irqchip @@ -56,7 +57,7 @@ lazy_static! { }; } -static mut INTERRUPTS_ARE_ENABLED: spin::Mutex = spin::Mutex::new(InterruptCounter { data: 0 }); +static mut INTERRUPTS_ARE_ENABLED: spin::Mutex = spin::Mutex::new(InterruptCounter::new()); // Disable network interrupts (is thread safe) pub fn disable_network_interrupts() { let mut data = unsafe { INTERRUPTS_ARE_ENABLED.lock() }; @@ -79,7 +80,7 @@ pub fn enable_network_interrupts() { pub extern "x86-interrupt" fn network_handle(_stack_frame: InterruptStackFrame) { // Try to get the device info - let mut net_dev = NET_INFO.lock_no_disable(); + let mut net_dev = unsafe { NET_INFO.lock_no_disable() }; if net_dev.is_none() { panic!("RTL_INFO is undefined!"); } diff --git a/kernel/src/network/socket.rs b/kernel/src/network/socket.rs index 10557a0..1f63644 100644 --- a/kernel/src/network/socket.rs +++ b/kernel/src/network/socket.rs @@ -4,81 +4,29 @@ use alloc::vec; use alloc::vec::Vec; use futures_util::StreamExt; -use crate::{ - network::{constants::TCP_ACK, layer::LayerType, tcp::TCPPacket}, - print, println, -}; +use crate::{network::layer::LayerType, println}; use super::{ - arp::ArpPacket, - constants::{ARP_PORT, BROADCAST_MAC, TCP_PSH}, - ethernet::{self, EthType, EthernetPacket}, + constants::TCP_PSH, + errors::NetworkErrors, + ethernet::{self, EthernetPacket}, ip::{IPPacket, Protocol}, layer::{HasChecksum, Layer, PacketData}, - raw_socket::{NetworkErrors, RawSocket}, + network_query::NetworkQuery, + raw_socket::RawSocket, rtl8139::{disable_network_interrupts, enable_network_interrupts, NET_INFO}, tcp_session::TCPSession, udp::UDPPacket, }; -pub struct NetworkQuery {} - -impl NetworkQuery { - pub async fn get_mac_from_ip(wait_timeout: u32, ip: u32) -> Option { - disable_network_interrupts(); - let mut rtl_dev_info_locked = NET_INFO.lock(); - let rtl_dev_info = rtl_dev_info_locked.get_mut().unwrap(); - for entry in rtl_dev_info.arp_table.iter() { - // todo: check for expired arps - if entry.ip == ip { - return Some(entry.mac); - } - } - // send arp packet - let eth_layer = EthernetPacket::gen(BROADCAST_MAC, rtl_dev_info.mac_address.unwrap(), EthType::Arp); - let arp_layer = ArpPacket::gen(eth_layer, rtl_dev_info.my_ip_address.unwrap(), ip, true); - rtl_dev_info.send_packet(&arp_layer.serialize()); - drop(rtl_dev_info_locked); - enable_network_interrupts(); - - // wait for response - let mut socket = RawSocket::new(ARP_PORT, 3).unwrap(); - let mut timeout = 0; - loop { - if let Some(pkt) = socket.next().await { - if pkt.is_err() { - continue; - } - let pkt_data = pkt.unwrap(); - if pkt_data.get_type() != LayerType::ARP { - continue; - } - let arp_pkt = pkt_data.unwrap_arp(); - return Some(arp_pkt.sender_mac.val()); - } else { - disable_network_interrupts(); - { - let rtl_dev_guard = NET_INFO.lock(); - let rtl_dev_info = rtl_dev_guard.get_ref().unwrap(); - rtl_dev_info.send_packet(&arp_layer.serialize()); - } - enable_network_interrupts(); - } - timeout += 1; - if timeout == wait_timeout * 6 { - socket.close(); - return None; - } - } - } -} - +/// An enum to represent socket type (UDP or TCP) #[derive(Debug, PartialEq, Eq, Clone, Copy)] pub enum SocketType { UDP, TCP, } +/// An enum to represent socket state (Listening, Ready to Send, or Closed) #[derive(Debug, PartialEq, Eq, Clone, Copy)] enum SocketState { Listening, @@ -86,6 +34,8 @@ enum SocketState { Closed, } +/// The socket object for communication with the network +/// Will utilize the network stack pub struct Socket { socket_type: SocketType, socket_state: SocketState, @@ -101,30 +51,48 @@ pub struct Socket { } impl Socket { - // Can't send yet -- must listen + /// Open a socket in listening mode on src_port with a timeout + /// Can be TCP or UDP + /// + /// - *Can't send yet -- must listen first* + /// - or can use connect instead pub async fn open(socket_type: SocketType, src_port: u16, wait_timeout: u16) -> Result { + // Start the chosen src port as src_port let mut chosen_src_port = src_port; + + // Acquire the network driver disable_network_interrupts(); let rtl_dev_info_locked = NET_INFO.lock(); let rtl_dev_info = rtl_dev_info_locked.get_ref().unwrap(); + + // get the src mac and address from the driver let src_mac = rtl_dev_info.mac_address.unwrap(); let src_address = rtl_dev_info.my_ip_address.unwrap(); + // If our source port is 0 if src_port == 0 { let open_ports = &rtl_dev_info.open_ports; // Linear probe to see if any ports are open for i in 1000..u16::MAX { if !open_ports.contains(&(i as u64)) { + // Once we find an open port from 1000 - 65535, we set the chosen_src_port and break chosen_src_port = i; break; } } } + // Release the driver code drop(rtl_dev_info_locked); enable_network_interrupts(); + + // If our chosen src port is 0, we have no available port error if chosen_src_port == 0 { return Err(NetworkErrors::NoAvailablePort); } + + // create a raw socket for the socket to use let raw_socket = RawSocket::new(chosen_src_port as u64, max(wait_timeout * 18, 1)); + + // If the raw_socket is good, return the socket wrapper match raw_socket { Ok(socket) => Ok(Socket { socket_type, @@ -139,45 +107,67 @@ impl Socket { wait_timeout, session_key: 0, }), + // otherwise return the error from the raw socket construction Err(err) => Err(err), } } - // Will listen for new connections and create new sessions - // UDP can only listen for one connection (and thus will return none). -- UDP is connectionless + /// Will listen for new connections and create new sessions once something arrives + /// - **UDP can only listen for one connection (and thus will return none)** + /// - **UDP must still call listen to transition into the correct state** pub async fn listen(&mut self) -> Option { if self.socket_state != SocketState::Listening { return None; // todo: this is failing silently } loop { + // Loop on packets from the raw socket if let Some(pkt_or_err) = self.raw_socket.next().await { + // If we get an error from the socket, return None if pkt_or_err.is_err() { // todo: this is failing silently return None; } + // Unwrap the packet let pkt = pkt_or_err.unwrap(); if pkt.get_type() == LayerType::UDP && self.socket_type == SocketType::UDP { + // If we have a match (UDP <--> UDP) + // then we unwrap the udp packet and set the dest_port, dest_mac, and dest_ip let udp_pkt = pkt.unwrap_udp(); self.dest_port = udp_pkt.src_port.val(); self.dest_mac = udp_pkt.ip.eth.src_mac.val(); self.dest_ip = udp_pkt.ip.source_ip.val(); + // and transition into ready state self.socket_state = SocketState::Ready; + + // acquire the driver disable_network_interrupts(); let mut rtl_dev_info_guard = NET_INFO.lock(); let rtl_dev_info = rtl_dev_info_guard.get_mut().unwrap(); + // re-enqueue the packet if let Some(vec) = rtl_dev_info.ports.get_mut(&(self.src_port as u64)) { vec.push_front(Ok(PacketData::UDP(udp_pkt))); } + + // release the driver + drop(rtl_dev_info_guard); enable_network_interrupts(); + // and return None return None; } else if pkt.get_type() == LayerType::TCP && self.socket_type == SocketType::TCP { + // If we have a match (TCP <--> TCP) + // Then we unwrap the TCP packet let tcp_pkt = pkt.unwrap_tcp(); + // And extract and save the dest_address and dest_port let dest_address = tcp_pkt.ip.source_ip.val(); let dest_port = tcp_pkt.src_port.val(); + // We create a session key let session_key = TCPSession::gen_session_key(dest_address, dest_port, self.src_port); + // And a raw socket let raw_socket = RawSocket::new(session_key, max(self.wait_timeout * 18, 1)).unwrap(); println!("[INFO] Spawned new TCP session"); + // And return a new socket object to be the ready socket + // the current socket never transitions out of listening return Some(Socket { socket_type: SocketType::TCP, raw_socket, @@ -191,15 +181,13 @@ impl Socket { wait_timeout: self.wait_timeout, session_key, }); - } else { - // println!("[DEBUG] not useful packet"); } - } else { - // println!("[DEBUG] Got none"); } } } + /// Connect to a foreign socket that is listening + // todo: untested pub async fn connect( socket_type: SocketType, dest_address: u32, @@ -207,32 +195,49 @@ impl Socket { src_port: u16, wait_timeout: u16, ) -> Result { + // Get the chosen port let mut chosen_src_port = src_port; + + // Acquire the driver disable_network_interrupts(); let rtl_dev_info_locked = NET_INFO.lock(); let rtl_dev_info = rtl_dev_info_locked.get_ref().unwrap(); + + // Extract src mac and src address let src_mac = rtl_dev_info.mac_address.unwrap(); let src_address = rtl_dev_info.my_ip_address.unwrap(); + + // if src port is 0 if src_port == 0 { let open_ports = &rtl_dev_info.open_ports; // Linear probe to see if any ports are open for i in 1000..u16::MAX { if !open_ports.contains(&(i as u64)) { + // once we find an open port, save it chosen_src_port = i; break; } } } - let dest_mac = NetworkQuery::get_mac_from_ip(10, dest_address).await; + + // Release the driver drop(rtl_dev_info_locked); enable_network_interrupts(); + + // Also query for a destination mac address + let dest_mac = NetworkQuery::get_mac_from_ip(10, dest_address).await; + + // If destination mac is none, we have a non-existent host error if dest_mac.is_none() { return Err(NetworkErrors::NonexistentHost); } + // If source port is 0, we have no available ports to bind to if chosen_src_port == 0 { return Err(NetworkErrors::NoAvailablePort); } + // Create a raw socket to work with let raw_socket = RawSocket::new(chosen_src_port as u64, max(wait_timeout * 18, 1)); + // If the raw socket was created successfully, we return a new socket object match raw_socket { Ok(socket) => Ok(Socket { socket_type, @@ -251,18 +256,25 @@ impl Socket { } } + /// Close the socket and release the resources associated with it pub async fn close(mut self) { if self.socket_state == SocketState::Closed { // Already closed return; } if self.socket_type == SocketType::TCP { + // If tcp, do special stuff + // First we get the session stuff let mut rtl_dev_config = NET_INFO.config.lock(); + // Then we get a session key let session_key = TCPSession::gen_session_key(self.dest_ip, self.dest_port, self.src_port); if rtl_dev_config.tcp_sessions.contains_key(&session_key) { + // If we have a tcp session, create a fin_ack packet from the raw socket let session = rtl_dev_config.tcp_sessions.get_mut(&session_key).unwrap(); let fin_ack_pkt = session.close(); + // Drop the session stuff we don't need it anymore drop(rtl_dev_config); + // If our fin_ack packet creation was successful if let Ok(pkt) = fin_ack_pkt { // With 6 retries 'outer: for _ in 0..6 { @@ -280,13 +292,13 @@ impl Socket { enable_network_interrupts(); loop { - // Keep reading the stream + // Keep reading the stream until we get an error if let Err(next) = self.read_tcp(0).await { if next == NetworkErrors::ClosedSocket { // If we have a closed stream, we break out completely break 'outer; } - // Likely got a timeout - so retry + // Likely got a timeout - so retry by sending another packet break; } } @@ -296,104 +308,168 @@ impl Socket { NET_INFO.config.lock().tcp_sessions.remove(&session_key); } } + // Close the raw socket self.raw_socket.close(); } + /// Read data from the socket pub async fn read(&mut self, size: usize) -> Result, NetworkErrors> { + // Check socket state if self.socket_state != SocketState::Ready { return Err(NetworkErrors::BadSocketState); } + // Match socket type and go to appropriate internal function match self.socket_type { SocketType::UDP => self.read_udp(size).await, SocketType::TCP => self.read_tcp(size).await, } } + /// Internal function for reading as a UDP socket async fn read_udp(&mut self, size: usize) -> Result, NetworkErrors> { loop { + // Spin until we get some packet if let Some(pkt_or_err) = self.raw_socket.next().await { + // If we poll an error, we pass it to the read as an error if let Err(err) = pkt_or_err { return Err(err); } let pkt = pkt_or_err.unwrap(); if pkt.get_type() == LayerType::UDP { + // If we have a matching UDP packet, we pass the data to the read result let udp_pkt = pkt.unwrap_udp(); return Ok(udp_pkt.data); } } else { + // We got a timeout and we return for UDP when this happens return Err(NetworkErrors::Timeout); } } } + /// Internal function for reading as a TCP socket async fn read_tcp(&mut self, size: usize) -> Result, NetworkErrors> { loop { + // Create a result vector let mut res_vec = vec![]; + // Loop until we get a packet if let Some(pkt_or_err) = self.raw_socket.next().await { + // If the packet is an error, read_tcp returns an error if let Err(err) = pkt_or_err { return Err(err); } let pkt = pkt_or_err.unwrap(); if pkt.get_type() == LayerType::TCP { + // If our packet matches our intended type, save the packet into our buffer let mut tcp_pkt = pkt.unwrap_tcp(); res_vec.append(&mut tcp_pkt.data); + // If our packet has PSH, we ignore the size request and immediately push to the application if tcp_pkt.get_flags() & TCP_PSH != 0 { return Ok(res_vec); } } + // Our size has exceeded the request, so we can return the resulting data if size <= res_vec.len() { return Ok(res_vec); } } else { + // Our socket timed-out reading, so we return the timeout error return Err(NetworkErrors::Timeout); } } } + /// Write data to the socket pub async fn write(&mut self, data: &mut Vec) -> Result { + // Check socket state + if self.socket_state != SocketState::Ready { + return Err(NetworkErrors::BadSocketState); + } + // Match socket type and go to appropriate internal function match self.socket_type { SocketType::UDP => self.write_udp(data), SocketType::TCP => self.write_tcp(data).await, } } + /// Internal function for writing to the UDP socket + fn write_udp(&self, data: &mut Vec) -> Result { + // Check states + assert!(self.socket_type == SocketType::UDP); + if self.socket_state != SocketState::Ready { + return Err(NetworkErrors::BadSocketState); + } + // Build a UDP packet + let eth_layer = EthernetPacket::gen(self.src_mac, self.dest_mac, ethernet::EthType::IPv4); + let udp_size = UDPPacket::packet_size() + data.len() as u16; + let ip_layer = IPPacket::gen(eth_layer, udp_size, Protocol::UDP, self.src_address, self.dest_ip); + let data_len = data.len(); + let mut udp_layer = UDPPacket::gen(ip_layer, self.src_port, self.dest_port, data_len as u16); + udp_layer.data = data.to_vec(); // todo: split the data and return the amount actually written! + + // Calculate checksum + let data_2_send = udp_layer.serialize(); + let start_udp = data_2_send.len() - (UDPPacket::packet_size() as usize + data_len); + let start_ip = start_udp - (IPPacket::packet_size() as usize); + udp_layer.ip.calculate_checksum(&data_2_send[start_ip..start_udp]); + udp_layer.calculate_checksum(&data_2_send[start_udp..]); + + // Serialize + let data_2_send_final = udp_layer.serialize(); + + // Send the packet + disable_network_interrupts(); + NET_INFO.lock().get_ref().unwrap().send_packet(&data_2_send_final); + enable_network_interrupts(); + + // Return how much was written + Ok(data_len as u16) + } + + /// Internal function for writing tcp async fn write_tcp(&mut self, data: &mut Vec) -> Result { + // Check states assert!(self.socket_type == SocketType::TCP); if self.socket_state != SocketState::Ready { return Err(NetworkErrors::BadSocketState); } - // Set up to receive ack + // Set up to receive ack by processing the send on the data let mut tcp_session_guard = NET_INFO.config.lock(); let tcp_session = tcp_session_guard.tcp_sessions.get_mut(&self.session_key).unwrap(); + // Message pkt is our present to send to our server let message_pkt = tcp_session.process_send(data); + // Do error checking if let Err(err) = message_pkt { return Err(err); } - let message = message_pkt.unwrap().serialize(); // todo: error check + // Unwrap the message + let message = message_pkt.unwrap().serialize(); drop(tcp_session_guard); - // Wait for the ack + // Wait for the ack -- 20 retries for retries in 1..21 { - // Send packet + // Acquire the driver disable_network_interrupts(); let mut rtl_dev_info_locked = NET_INFO.lock(); let rtl_dev_info = rtl_dev_info_locked.get_mut().unwrap(); + // Send the packet rtl_dev_info.send_packet(&message); + // Release the driver drop(rtl_dev_info_locked); enable_network_interrupts(); // Get next packet or timeout if let Some(pkt_or_err) = self.raw_socket.next().await { + // If our packet is an error, we return it immediately if let Err(err) = pkt_or_err { return Err(err); } let pkt = pkt_or_err.unwrap(); if pkt.get_type() == LayerType::TCP { - // todo: better notification system for raw_socket (alternative to .next)... - // let pkt_data = pkt.unwrap_tcp(); - // Check the acknowledgement to make sure everything is acked + // We got a TCP packet let mut tcp_session_guard = NET_INFO.config.lock(); + // Check the acknowledgement to make sure everything is acked let tcp_session = tcp_session_guard.tcp_sessions.get_mut(&self.session_key).unwrap(); if tcp_session.everything_acked() { // if so break out of the timeout loop @@ -409,27 +485,4 @@ impl Socket { // Return ok Ok(data.len() as u16) } - - fn write_udp(&self, data: &mut Vec) -> Result { - assert!(self.socket_type == SocketType::UDP); - if self.socket_state != SocketState::Ready { - return Err(NetworkErrors::BadSocketState); - } - let eth_layer = EthernetPacket::gen(self.src_mac, self.dest_mac, ethernet::EthType::IPv4); - let udp_size = UDPPacket::packet_size() + data.len() as u16; - let ip_layer = IPPacket::gen(eth_layer, udp_size, Protocol::UDP, self.src_address, self.dest_ip); - let data_len = data.len(); - let mut udp_layer = UDPPacket::gen(ip_layer, self.src_port, self.dest_port, data_len as u16); - udp_layer.data = data.to_vec(); // todo: split the data and return the amount actually written! - let data_2_send = udp_layer.serialize(); - let start_udp = data_2_send.len() - (UDPPacket::packet_size() as usize + data_len); - let start_ip = start_udp - (IPPacket::packet_size() as usize); - udp_layer.ip.calculate_checksum(&data_2_send[start_ip..start_udp]); - udp_layer.calculate_checksum(&data_2_send[start_udp..]); - let data_2_send_final = udp_layer.serialize(); - disable_network_interrupts(); - NET_INFO.lock().get_ref().unwrap().send_packet(&data_2_send_final); - enable_network_interrupts(); - Ok(data_len as u16) - } } diff --git a/kernel/src/network/tcp.rs b/kernel/src/network/tcp.rs index 6b7f3f2..c99ecd2 100644 --- a/kernel/src/network/tcp.rs +++ b/kernel/src/network/tcp.rs @@ -7,22 +7,35 @@ use super::{ layer::{HasChecksum, Layer, LayerType}, }; +/// A TCP packet, implements Layer and HasChecksum (20 bytes) #[derive(Debug, Clone)] pub struct TCPPacket { + /// The parent packet pub ip: IPPacket, - pub src_port: Bytefield16, // 2 bytes - pub dest_port: Bytefield16, // 2 bytes - pub seq_num: Bytefield32, // 4 bytes - pub ack_num: Bytefield32, // 4 bytes - pub flags: Bytefield16, // 2 bytes - pub sliding_window: Bytefield16, // 2 bytes - pub checksum: Bytefield16, // 2 bytes - pub urgent: Bytefield16, // +2 more, 20 bytes up until here + /// The source port + pub src_port: Bytefield16, + /// The destination port + pub dest_port: Bytefield16, + /// The sequence number of the TCP packet + pub seq_num: Bytefield32, + /// The acknowledgemnet number of the TCP packet + pub ack_num: Bytefield32, + /// The flags of the TCP packet + pub flags: Bytefield16, + /// The sliding window, determines how much congestion control is performed by the receiver + pub sliding_window: Bytefield16, + /// The checksum to prevent corruption + pub checksum: Bytefield16, + /// If the data is urgent, this is the urgent pointer to the data that is urgent + pub urgent: Bytefield16, + /// Options (unused) pub options: Vec, + /// The data of the TCP packet pub data: Vec, } impl TCPPacket { + /// Create an empty packet with all 0s pub fn new() -> Self { TCPPacket { ip: IPPacket::new(), @@ -39,66 +52,77 @@ impl TCPPacket { } } - pub fn gen(ip_packet: IPPacket, src_port: u16, dest_port: u16) -> Self { + /// Generate a TCP packet with + /// - ip_layer: is the ip layer associated with the packet + /// - src_port: the source port (open on this machine) + /// - dest_port: the destination port (open on the server) + pub fn gen(ip_layer: IPPacket, src_port: u16, dest_port: u16) -> Self { TCPPacket { - ip: ip_packet, + ip: ip_layer, src_port: Bytefield16::new(src_port), dest_port: Bytefield16::new(dest_port), seq_num: Bytefield32::new(0), ack_num: Bytefield32::new(0), - flags: Bytefield16::new(5 << 12), // a value of 5 in the header_offset (5*4 = 20 bits -> b/c no options) - sliding_window: Bytefield16::new(u16::MAX), // sliding window :(, pain to implement... we can just allow unlimited data but thats insecure... Also we would like to keep track of how much we are allowed to send + // a value of 5 in the header_offset (5*4 = 20 bits -> b/c no options) + flags: Bytefield16::new(5 << 12), + // sliding window :(, pain to implement... + // we can just allow unlimited data... Also we would like to keep track of how much we are allowed to send + sliding_window: Bytefield16::new(u16::MAX), checksum: Bytefield16::new(0), urgent: Bytefield16::new(0), + // we never provide options because we basic options: vec![], - // todo: we set the mss here manually -- fix data: vec![], } } + /// The size of the options of the TCP packet pub fn options_size() -> u16 { 0 } - // N.B. Getting operations DONT swap endianness because we should be in host order (and after parsing we are) - // Get the header offset + /// N.B. Getting operations DONT swap endianness because we should be in host order (and after parsing we are) + /// Get the header offset pub fn get_header_offset(&self) -> u8 { // convert header offset into bytes ((self.flags.val() >> 12) as u8) * 4 } - // Get the flags + /// N.B. Getting operations DONT swap endianness because we should be in host order (and after parsing we are) + /// Get the flags pub fn get_flags(&self) -> u8 { (self.flags.val() & 0x00FF) as u8 } - // N.B. Setting operations swap endianness because when we use val, we are in network-order - // Turn off the flags provided - pub fn turn_off_flags(&mut self, flags: u8) { - let new_flags = self.flags.swapped_endianness().val() & !(flags as u16); - self.flags = Bytefield16::new(new_flags); - } - - // Turn on the flags provided + /// N.B. Setting operations swap endianness because when we use val, we are in network-order + /// Turn on the flags provided pub fn turn_on_flags(&mut self, flags: u8) { let new_flags = self.flags.swapped_endianness().val() | (flags as u16); self.flags = Bytefield16::new(new_flags); } - // Get total length of the tcp portion + /// N.B. Getting operations DONT swap endianness because we should be in host order (and after parsing we are) + /// Get total length of the tcp portion pub fn total_size(&self) -> u16 { (20 + self.options.len() + self.data.len()) as u16 } } impl Layer for TCPPacket { + /// The input layer for parse type Input = IPPacket; + + /// Parsing a TCP packet requires: + /// - ip_layer: a parsed IP packet + /// - bytevec: the data to parse, with trailing but it starts where the packet must begin fn parse(ip_layer: IPPacket, bytevec: &[u8]) -> (Self, usize, LayerType) where Self: Sized, { - let mut packet = TCPPacket::new(); // create an empty packet - // Read 14 bytes + // Create an empty packet + let mut packet = TCPPacket::new(); + + // Save ip packet and read 14 bytes let mut i = 0; packet.ip = ip_layer; packet.src_port = Bytefield16::read_inc(&bytevec[i..], &mut i); @@ -109,21 +133,27 @@ impl Layer for TCPPacket { packet.sliding_window = Bytefield16::read_inc(&bytevec[i..], &mut i); packet.checksum = Bytefield16::read_inc(&bytevec[i..], &mut i); packet.urgent = Bytefield16::read_inc(&bytevec[i..], &mut i); - // read remaining bytes and place them into the data buffer + // read remaining bytes of heaer and place them into the options buffer for _ in 0..(packet.get_header_offset() - 20) { packet.options.push(bytevec[i]); i += 1; } - assert!(i == packet.get_header_offset().into()); // Valid i is header_length + // Assert we read all headers + assert!(i == packet.get_header_offset().into()); + // Get the data size let data_size = packet.ip.total_length.val() - packet.get_header_offset() as u16 - IPPacket::packet_size(); + // Read everything else and save it into data buffer for _ in 0..data_size { packet.data.push(bytevec[i]); i += 1; } + // Return the packet, the amount of data consumed, and the next layer type (end) (packet, i, LayerType::END) } + /// Serialize the packet into a vector of bytes, ready to send over the network fn serialize(&self) -> alloc::vec::Vec { + // Create a vector and serialize it let mut res = vec![]; res.extend(self.ip.serialize()); res.extend(self.src_port.data); @@ -140,12 +170,15 @@ impl Layer for TCPPacket { res } + /// The amount of data that belongs to the packet-type fn packet_size() -> u16 { 20 } } impl HasChecksum for TCPPacket { + /// Calculate a checksum on the data and the packet + /// - will self mutate fn calculate_checksum(&mut self, data: &[u8]) { // Starting vars let mut sum: u32 = 0; diff --git a/kernel/src/network/tcp_session.rs b/kernel/src/network/tcp_session.rs index d5e76c8..ca6fba9 100644 --- a/kernel/src/network/tcp_session.rs +++ b/kernel/src/network/tcp_session.rs @@ -5,33 +5,48 @@ use crate::println; use super::{ bytefield::{Bytefield16, Bytefield32}, constants::{TCP_ACK, TCP_FIN, TCP_PSH, TCP_RST, TCP_SYN}, + errors::NetworkErrors, ethernet::EthernetPacket, ip::IPPacket, layer::{HasChecksum, Layer}, - raw_socket::NetworkErrors, tcp::TCPPacket, }; +/// An enum for different TCP session states #[derive(Debug, PartialEq, Eq)] pub enum TCPSessionState { + /// Waiting for SYN packets to start a connection Waiting, + /// Syncing means we are doing our three-way handshake Syncing, + /// Established is we can communicate Established, + /// Closing means we are in the process of the 4-way FIN handshake Closing, + /// Closed means no more data and any more use of this TCP session should lead to End-Of-Stream tokens Closed, } +/// An enum to be the action we take after processing a packet (receiving end only) #[derive(Debug, PartialEq, Eq)] pub enum SessionAction { + /// Push a packet upstream to a higher level of abstraction (i.e. ports) PushUpstream, + /// Drop a packet for lack of usefulness or bad state Drop, + /// End of stream, we are closing our tcp session EndOfStream, } +/// An object for managing the complexities of TCP pub struct TCPSession { + /// A internal template with prefilled fields for talking in a certain session session_template: TCPPacket, + /// The destination IP address pub dest_ip: u32, + /// The destination port pub dest_port: u16, + /// The source port pub src_port: u16, /// our seq num pub sent_data_amount: u32, @@ -39,25 +54,34 @@ pub struct TCPSession { pub sent_data_acked: u32, /// our ack num pub recv_data_amount: u32, - // pub recv_data_acked: u32, implicit how much we have acked -- this value is unknown to use + // pub recv_data_acked: u32, implicit how much we have acked -- this value is unknown to user + // todo: implement a use for window size + /// Window size is the window size of the last packet we received pub window_size: u16, /// If the user has sent fin_ack closing has_sent_fin_ack: bool, + /// If we have receieved an ack to our FIN-ACK message has_recv_ack_to_fin_ack: bool, + /// If we have sent an ack too a FIN-ACK message has_sent_ack_to_fin_ack: bool, + /// The session state of the TCPSession pub session_state: TCPSessionState, } impl TCPSession { + /// Generate a new tcp session with a template and the three fields required for a tcp session key + /// - dest_ip + dest_port + src_port = session_key pub fn new(session_template: TCPPacket, dest_ip: u32, dest_port: u16, src_port: u16) -> Self { TCPSession { session_template, + // todo: integrate with randomness to init an initial sequence number sent_data_amount: 55304, sent_data_acked: 55304, recv_data_amount: 0, dest_ip, dest_port, src_port, + // we max our window size because we want data ASAP window_size: u16::MAX, has_sent_fin_ack: false, has_recv_ack_to_fin_ack: false, @@ -66,19 +90,17 @@ impl TCPSession { } } - // todo: replace for random generated number - pub fn gen_starting_seq_num() -> u32 { - 0 - } - + /// Get the session key from this tcp session pub fn session_key(&self) -> u64 { Self::gen_session_key(self.dest_ip, self.dest_port, self.src_port) } + /// Static function for generating a session key from (dest_ip + dest_port + src_port) pub fn gen_session_key(dest_ip: u32, dest_port: u16, src_port: u16) -> u64 { (dest_ip as u64) << 32 | (dest_port as u64) << 16 | src_port as u64 } + /// If the TCPSessionState is essentially closed pub fn is_closed(&self) -> bool { self.session_state == TCPSessionState::Closed || self.has_recv_ack_to_fin_ack && self.has_sent_ack_to_fin_ack } @@ -89,11 +111,14 @@ impl TCPSession { // Transition to the closing state self.session_state = TCPSessionState::Closing; + // Set the has_sent_fin_ack flag self.has_sent_fin_ack = true; + // Clone the template and turn on proper flags let mut tcp_pkt = self.session_template.clone(); tcp_pkt.turn_on_flags(TCP_FIN | TCP_ACK); - // add the data size (maybe make this automatic?) + // Add the data size + // todo: (maybe make this automatic?) tcp_pkt.ip.total_length = Bytefield16::new(TCPPacket::packet_size() + TCPPacket::options_size() + IPPacket::packet_size()); tcp_pkt.seq_num = Bytefield32::new(self.sent_data_amount); tcp_pkt.ack_num = Bytefield32::new(self.recv_data_amount); @@ -105,43 +130,31 @@ impl TCPSession { tcp_pkt.ip.calculate_checksum(&data[start_ip..start_tcp]); tcp_pkt.calculate_checksum(&data[start_tcp..]); + // Return the packet Ok(tcp_pkt) } - pub fn reset(&mut self) -> TCPPacket { - self.session_state = TCPSessionState::Closed; - let mut tcp_pkt = self.session_template.clone(); - tcp_pkt.turn_on_flags(TCP_RST); - - // add the data size (maybe make this automatic?) - tcp_pkt.ip.total_length = Bytefield16::new(TCPPacket::packet_size() + TCPPacket::options_size() + IPPacket::packet_size()); - tcp_pkt.seq_num = Bytefield32::new(self.sent_data_amount); - tcp_pkt.ack_num = Bytefield32::new(self.recv_data_amount); - - // Calculate checksums - let data = tcp_pkt.serialize(); - let start_tcp = IPPacket::packet_size() as usize + EthernetPacket::packet_size() as usize; - let start_ip = EthernetPacket::packet_size() as usize; - tcp_pkt.ip.calculate_checksum(&data[start_ip..start_tcp]); - tcp_pkt.calculate_checksum(&data[start_tcp..]); - tcp_pkt - } - + /// A function for determining if we have acked all writes pub fn everything_acked(&self) -> bool { self.sent_data_acked == self.sent_data_amount } + /// A function for generating a packet to sent with data pub fn process_send(&mut self, data: &Vec) -> Result { + // Check session state if self.session_state != TCPSessionState::Established { return Err(NetworkErrors::BadSocketState); } + // Clone the template let mut tcp_pkt = self.session_template.clone(); + // Set flags and data tcp_pkt.turn_on_flags(TCP_ACK | TCP_PSH); tcp_pkt.data = data.to_vec(); // todo: split the data and return the amount actually written! - // add the data size (maybe make this automatic?) + // add the data size tcp_pkt.ip.total_length = Bytefield16::new(TCPPacket::packet_size() + TCPPacket::options_size() + IPPacket::packet_size() + data.len() as u16); + // Set sequence number and ack_num tcp_pkt.seq_num = Bytefield32::new(self.sent_data_amount); self.sent_data_amount += data.len() as u32; tcp_pkt.ack_num = Bytefield32::new(self.recv_data_amount); @@ -153,14 +166,19 @@ impl TCPSession { tcp_pkt.ip.calculate_checksum(&data[start_ip..start_tcp]); tcp_pkt.calculate_checksum(&data[start_tcp..]); + // Return packet Ok(tcp_pkt) } + /// A function for processing a recevied packet and generating a response/action + /// - will return the response and session action as a tuple pub fn process_recv(&mut self, request: &TCPPacket) -> (Option, SessionAction) { + // Extract flag results let has_syn_flag = (request.get_flags() & TCP_SYN) != 0; let has_ack_flag = (request.get_flags() & TCP_ACK) != 0; let has_fin_flag = (request.get_flags() & TCP_FIN) != 0; let has_rst_flag = (request.get_flags() & TCP_RST) != 0; + // Set initial action as push upstream let mut response_action = SessionAction::PushUpstream; // todo: regularly update window size self.window_size = request.sliding_window.val(); @@ -169,29 +187,39 @@ impl TCPSession { response.ip.total_length = Bytefield16::new(TCPPacket::packet_size() + TCPPacket::options_size() + IPPacket::packet_size()); // todo! check seq-num and ack-num before doing these state changes if has_fin_flag && self.session_state == TCPSessionState::Established { + // Transition to closing state self.session_state = TCPSessionState::Closing; } if has_rst_flag { + // transition to closed state --> dying immediately because RST packet println!("[TCP] Received RST -- dying {:?}", self.session_state); self.session_state = TCPSessionState::Closed; return (None, SessionAction::Drop); } else { + // match on session state to decide what to do match self.session_state { TCPSessionState::Waiting => { if has_syn_flag && has_ack_flag { + // If we received a SYN-ACK, we respond with a single ACK message response.turn_on_flags(TCP_ACK); if request.ack_num.val() != self.sent_data_amount + 1 { - // Incorrect ack num + // If their ack_num isn't correct, we drop the packet return (None, SessionAction::Drop); } + // Increment the seq_nums self.sent_data_amount += 1; self.sent_data_acked += 1; + // Set our ack nums by their sequence number self.recv_data_amount = request.seq_num.val() + 1; + // Set our ack_num and seq_num for our response response.ack_num = Bytefield32::new(self.recv_data_amount); response.seq_num = Bytefield32::new(self.sent_data_amount); } else if has_syn_flag { + // If we received just a SYN, send back a SYN/ACK response.turn_on_flags(TCP_SYN | TCP_ACK); + // Increment our ack num by 1 self.recv_data_amount = request.seq_num.val() + 1; + // Set our responses ack and seq num response.ack_num = Bytefield32::new(self.recv_data_amount); response.seq_num = Bytefield32::new(self.sent_data_amount); } else { @@ -199,18 +227,22 @@ impl TCPSession { // Therefore we drop the packet return (None, SessionAction::Drop); } + // Transition to next state self.session_state = TCPSessionState::Syncing; } TCPSessionState::Syncing => { if has_ack_flag { + // Got ACK packet if request.ack_num.val() != self.sent_data_amount + 1 { - // Incorrect ack num + // Incorrect ack num so drop return (None, SessionAction::Drop); } + // Increment ack and seq nums by 1 self.sent_data_amount += 1; self.sent_data_acked += 1; + // Transition to next state self.session_state = TCPSessionState::Established; - // Has no need for a response + // Has no need for a response, but push upstream return (None, SessionAction::PushUpstream); } else { // Waiting on the ack packet @@ -219,40 +251,58 @@ impl TCPSession { } } TCPSessionState::Established => { + // Default state is dropping let mut has_info = SessionAction::Drop; if request.ack_num.val() > self.sent_data_acked && request.ack_num.val() <= self.sent_data_amount { // We are updating our record of how much data was acked self.sent_data_acked = request.ack_num.val(); - has_info = SessionAction::PushUpstream; // todo: why do we need to push up an empty packet? + // Even with no data, we push our packet upstream because --> + // the ack needs to release the write (since it blocks until its sure the write occured) + has_info = SessionAction::PushUpstream; } + + // Check sequence number for a match and if we have data if request.seq_num.val() == self.recv_data_amount && !request.data.is_empty() { - // Sequence number matches and we have data (So we need to ack) + // Then we increment our ack_num (what we've received) self.recv_data_amount += request.data.len() as u32; + // Set response seq and ack num, and turn on ack flag response.ack_num = Bytefield32::new(self.recv_data_amount); response.seq_num = Bytefield32::new(self.sent_data_amount); response.turn_on_flags(TCP_ACK); } else { + // No data is present, so we push our packet upstream ONLY if it acked a write return (None, has_info); } } TCPSessionState::Closing => { if has_fin_flag && has_ack_flag { + // If received FIN_ACK, then set the flag self.has_sent_ack_to_fin_ack = true; if self.has_recv_ack_to_fin_ack { + // If we've received the ACK to our FIN_ACK, we are closed self.session_state = TCPSessionState::Closed; } + // We try to send an ack and push end of stream towards socket response.turn_on_flags(TCP_ACK); response_action = SessionAction::EndOfStream; } else if has_ack_flag && self.has_sent_fin_ack { + // If just has ACK but we've sent a FIN_ACK + // Then we increment our seq num and set the flag self.sent_data_amount += 1; self.has_recv_ack_to_fin_ack = true; + + // If we've sent an ack to the fin ack, we transition to closed if self.has_sent_ack_to_fin_ack { self.session_state = TCPSessionState::Closed; } + + // And push our packet upstream to release any closing sockets return (None, SessionAction::PushUpstream); } + // Increment ack_num self.recv_data_amount += 1; + // And set response ack and seq num response.ack_num = Bytefield32::new(self.recv_data_amount); response.seq_num = Bytefield32::new(self.sent_data_amount); } @@ -268,6 +318,8 @@ impl TCPSession { let start_ip = EthernetPacket::packet_size() as usize; response.ip.calculate_checksum(&data[start_ip..start_tcp]); response.calculate_checksum(&data[start_tcp..]); + + // Return result and action (Some(response), response_action) } } diff --git a/kernel/src/network/udp.rs b/kernel/src/network/udp.rs index e35000b..e5b0d9d 100644 --- a/kernel/src/network/udp.rs +++ b/kernel/src/network/udp.rs @@ -7,18 +7,26 @@ use super::{ use alloc::vec; use alloc::vec::Vec; +/// A UDP packet, implements Layer and HasChecksum (8 bytes) #[derive(Debug)] pub struct UDPPacket { - pub ip: IPPacket, // public for checksumming - pub src_port: Bytefield16, // 2 bytes - pub dest_port: Bytefield16, // 2 bytes - pub length: Bytefield16, // 2 bytes - checksum: Bytefield16, // 2 bytes - pub data: Vec, // a vector for data bytes if needed - // 10 bytes total + /// The parent packet + pub ip: IPPacket, + /// Source port + pub src_port: Bytefield16, + /// Destination port + pub dest_port: Bytefield16, + /// Length of data + pub length: Bytefield16, + /// The checksum + checksum: Bytefield16, + /// a vector for data bytes if needed + pub data: Vec, + } impl UDPPacket { + /// Create an empty packet with all 0s pub fn new() -> Self { UDPPacket { ip: IPPacket::new(), @@ -30,12 +38,18 @@ impl UDPPacket { } } - pub fn gen(ip_packet: IPPacket, src_port: u16, dest_port: u16, length: u16) -> Self { + /// Generate a UDP packet with + /// - ip_layer: is the IP layer associated with the UDP layer + /// - src_port: the source port (open on this machine) + /// - dest_port: the destination port (open on the server) + /// - length: the length of the body (how much data under it) + pub fn gen(ip_layer: IPPacket, src_port: u16, dest_port: u16, length: u16) -> Self { UDPPacket { - ip: ip_packet, + ip: ip_layer, src_port: Bytefield16::new(src_port), dest_port: Bytefield16::new(dest_port), - length: Bytefield16::new(length + 8), // size of body + 8 bytes for UDP + // size of body + 8 bytes for UDP + length: Bytefield16::new(length + 8), checksum: Bytefield16::new(0), data: Vec::new(), } @@ -43,21 +57,31 @@ impl UDPPacket { } impl Layer for UDPPacket { + /// The input layer for parse type Input = IPPacket; + + /// Parsing a UDP packet requires: + /// - ip_layer: a parsed IP packet + /// - bytevec: the data to parse, with trailing but it starts where the packet must begin fn parse(ip_layer: IPPacket, bytevec: &[u8]) -> (Self, usize, LayerType) where Self: Sized, { - let mut packet = UDPPacket::new(); // create an empty packet - // Read 14 bytes + // create an empty packet + let mut packet = UDPPacket::new(); + + // Save ip packet and read 14 bytes let mut i = 0; packet.ip = ip_layer; packet.src_port = Bytefield16::read_inc(&bytevec[i..], &mut i); packet.dest_port = Bytefield16::read_inc(&bytevec[i..], &mut i); packet.length = Bytefield16::read_inc(&bytevec[i..], &mut i); packet.checksum = Bytefield16::read_inc(&bytevec[i..], &mut i); - assert!(i == 8); // 8 bytes + // assert 8 bytes + assert!(i == 8); + // Match the destionation port to see if its DHCP let layer_type = match packet.dest_port.val() { + // If port 68, send to DHCP layer 68 => LayerType::DHCP, _ => { // read remaining bytes and place them into the data buffer @@ -66,30 +90,37 @@ impl Layer for UDPPacket { i += 1; } assert!(i == packet.length.val() as usize); + // We are done parsing LayerType::END } }; + // Return the packet, the amount of data consumed, and the next layer type (end or DHCP) (packet, i, layer_type) } + /// Serialize the packet into a vector of bytes, ready to send over the network fn serialize(&self) -> alloc::vec::Vec { + // Create a vector and serialize it let mut res = vec![]; res.extend(self.ip.serialize()); res.extend(self.src_port.data); res.extend(self.dest_port.data); res.extend(self.length.data); res.extend(self.checksum.data); - res.extend(&self.data); // most of the time should be empty + res.extend(&self.data); assert!(res.len() == (8 + self.ip.serialize().len() + self.data.len())); res } + /// The amount of data that belongs to the packet-type fn packet_size() -> u16 { 8 } } impl HasChecksum for UDPPacket { + /// Calculate a checksum on the data and the packet + /// - will self mutate fn calculate_checksum(&mut self, data: &[u8]) { // Starting vars let mut sum: u32 = 0; diff --git a/kernel/src/task/executor.rs b/kernel/src/task/executor.rs index e9f270c..03a7fbe 100644 --- a/kernel/src/task/executor.rs +++ b/kernel/src/task/executor.rs @@ -5,12 +5,16 @@ use crossbeam_queue::ArrayQueue; use x86_64::instructions::interrupts::{self, enable_and_hlt}; pub struct Executor { + /// A tree of tasks tasks: BTreeMap, + /// A queue of task IDs (fifo scheduling) task_queue: Arc>, + /// A tree of wakers to wake tasks waker_cache: BTreeMap, } impl Executor { + /// Create a new executor pub fn new() -> Self { Executor { tasks: BTreeMap::new(), @@ -19,6 +23,7 @@ impl Executor { } } + /// Spawn a new task pub fn spawn(&mut self, task: Task) { let task_id = task.id; if self.tasks.insert(task.id, task).is_some() { @@ -27,6 +32,7 @@ impl Executor { self.task_queue.push(task_id).expect("queue full"); } + /// Run the executor forever pub fn run(&mut self) -> ! { loop { self.run_ready_tasks(); @@ -34,7 +40,7 @@ impl Executor { } } - // Wait until all tasks finish + /// Wait until all tasks finish pub fn wait(&mut self) { loop { self.run_ready_tasks(); @@ -47,6 +53,7 @@ impl Executor { } } + /// Internal function to sleep until the task_queue has stuff fn sleep_if_idle(&self) { interrupts::disable(); if self.task_queue.is_empty() { @@ -56,6 +63,7 @@ impl Executor { } } + /// Run any ready tasks until we have no tasks left fn run_ready_tasks(&mut self) { // destructure `self` to avoid borrow checker errors let Self { diff --git a/kernel/src/task/tcp_echo.rs b/kernel/src/task/tcp_echo.rs index a521817..d58a5df 100644 --- a/kernel/src/task/tcp_echo.rs +++ b/kernel/src/task/tcp_echo.rs @@ -1,59 +1,65 @@ use crate::{ network::{ - raw_socket::NetworkErrors, + errors::NetworkErrors, socket::{Socket, SocketType}, }, print, println, }; use alloc::string::String; +/// A TCP application that listens on port 6664 and echos any data sent to it +/// Used for testing or learning how to use the socket API pub async fn tcp_echo_server() { let socket_or_err = Socket::open(SocketType::TCP, 6664, 0).await; - match socket_or_err { - Ok(mut socket_gen) => { - // Allow 10 sockets - for _ in 0..10 { - let mut socket = socket_gen.listen().await.unwrap(); - loop { - let data_or_err = socket.read(0).await; - if let Ok(mut data) = data_or_err { - if data.is_empty() { - // continue if we didn't read any data - continue; - } - let user_message = String::from_utf8(data.clone()); - match user_message { - Ok(message) => { - print!("[USER] {}", message); - if message.to_lowercase().trim() == "quit" || message.to_lowercase().trim() == "exit" { - println!("Closing socket"); - let mut exit_message = "Closing socket...\n".as_bytes().to_vec(); - let _ = socket.write(&mut exit_message).await; - println!("Wrote final message to socket"); - socket.close().await; - println!("Closed socket"); - break; - } - } - Err(err) => println!("[USER-ERR] {:?}", err), - } - let res_or_err = socket.write(&mut data).await; - if let Err(err) = res_or_err { - println!("[ERR] (Writing): {:?}", err); - socket.close().await; - break; - } - } else if let Err(err) = data_or_err { - if err == NetworkErrors::Timeout { - continue; - } - println!("[ERR] (Reading): {:?}", err); + if let Err(err) = socket_or_err { + println!("[ERR] {:?}", err); + return; + } + let mut socket_gen = socket_or_err.unwrap(); + // Allow 10 sockets to form (we don't have threads so only listening to one at a time) + for _ in 0..10 { + // Listen for a new session socket to form + let mut socket = socket_gen.listen().await.unwrap(); + loop { + // Continously read from the socket + let data_or_err = socket.read(0).await; + if let Ok(mut data) = data_or_err { + if data.is_empty() { + // continue if we didn't read any data + continue; + } + // Get the user's message + let user_message = String::from_utf8(data.clone()); + if let Ok(message) = user_message { + // Print it out + print!("[USER] {}", message); + if message.to_lowercase().trim() == "quit" || message.to_lowercase().trim() == "exit" { + // If their message is quit or exit, we close the connection + let mut exit_message = "Closing socket...\n".as_bytes().to_vec(); + let _ = socket.write(&mut exit_message).await; socket.close().await; + println!("Closed socket"); break; } } + // Echo back the data from the socket + let res_or_err = socket.write(&mut data).await; + if let Err(err) = res_or_err { + // Writing - print error, close the socket, and break from the loop + println!("[ERR] (Writing): {:?}", err); + socket.close().await; + break; + } + } else if let Err(err) = data_or_err { + if err == NetworkErrors::Timeout { + // If we timed-out, we just loop again reading the socket + continue; + } + // Otherwise print the error, close the socket, and break from the loop + println!("[ERR] (Reading): {:?}", err); + socket.close().await; + break; } } - Err(err) => println!("[ERR] (Listening): {:?}", err), } } diff --git a/kernel/src/task/timeout.rs b/kernel/src/task/timeout.rs index 9606fc3..05c245c 100644 --- a/kernel/src/task/timeout.rs +++ b/kernel/src/task/timeout.rs @@ -3,26 +3,35 @@ use core::{cell::RefCell, sync::atomic::AtomicU64, task::Waker}; use lazy_static::lazy_static; use x86_64::instructions::interrupts; +/// An internal counter for how many timer interrupts occured static mut INTERRUPT_COUNTER: u64 = 0; +/// An id for a timeout #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] pub struct TimeoutID(u64); impl TimeoutID { + // Initialize a new timeout with unique ID pub fn new() -> Self { static NEXT_ID: AtomicU64 = AtomicU64::new(0); TimeoutID(NEXT_ID.fetch_add(1, core::sync::atomic::Ordering::Relaxed)) } } +/// An entry in the timeout structures struct TimeoutEntry { + /// The id of the timeout id: TimeoutID, + /// What epoch to wake at epochs: u64, + /// The waker to use to wake a task waker: Waker, + /// If the timeout was cancelled and we are just ignoring it cancelled: bool, } impl TimeoutEntry { + /// Create a new timeout entry pub fn new(id: TimeoutID, epochs: u64, waker: Waker) -> Self { TimeoutEntry { id, @@ -33,20 +42,24 @@ impl TimeoutEntry { } } +/// Define equality and ordering for timeout entry for min heap purposes based on epochs impl PartialEq for TimeoutEntry { fn eq(&self, other: &Self) -> bool { self.epochs == other.epochs } } +/// Define equality and ordering for timeout entry for min heap purposes based on epochs impl PartialOrd for TimeoutEntry { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } +/// Define equality and ordering for timeout entry for min heap purposes based on epochs impl Eq for TimeoutEntry {} +/// Define equality and ordering for timeout entry for min heap purposes based on epochs impl Ord for TimeoutEntry { fn cmp(&self, other: &Self) -> core::cmp::Ordering { self.epochs.cmp(&other.epochs).reverse() @@ -54,14 +67,18 @@ impl Ord for TimeoutEntry { } lazy_static! { + /// A binary heap of timeout entrys static ref TIMEOUT_QUEUE: spin::Mutex>> = spin::Mutex::new(BinaryHeap::new()); } +/// Read the interrupt counter +/// (it's fine is we have a data race so we don't lock, we don't care since we want an estimate of the time) pub fn read_interrupt_counter() -> u64 { unsafe { INTERRUPT_COUNTER } } -// Each epoch is ~1/18 of a second, experimentally +/// Register a new timeout entry and return its ID for cancellation +/// Each epoch is ~1/18 of a second, determined experimentally pub fn register_timeout(after_epochs: u16, waker: Waker) -> TimeoutID { let timeout_id = TimeoutID::new(); interrupts::without_interrupts(|| { @@ -75,11 +92,15 @@ pub fn register_timeout(after_epochs: u16, waker: Waker) -> TimeoutID { timeout_id } +/// Cancel a timeout with id pub fn cancel_timeout(id: TimeoutID) { + // Without interrupts (no timer interrupts) interrupts::without_interrupts(|| { + // Lock the queue and find our timeout entry let timeout_queue = TIMEOUT_QUEUE.lock(); for entry in timeout_queue.iter() { if entry.borrow().id.0 == id.0 { + // Once found, we cancel the timeout entry entry.borrow_mut().cancelled = true; break; } @@ -87,17 +108,26 @@ pub fn cancel_timeout(id: TimeoutID) { }); } -// Only run from the interrupt context +/// Poll for timeouts and wake up anything expired +/// *Only run from the interrupt context* pub fn poll_timeouts() { + // We lock the timeout queue let mut timeout_queue = TIMEOUT_QUEUE.lock(); + // Increment the counter unsafe { INTERRUPT_COUNTER += 1 }; + // And continously read timeout entrys while let Some(timeout_entry) = timeout_queue.peek() { + // If the timeout entry is expired if timeout_entry.borrow().epochs <= unsafe { INTERRUPT_COUNTER } { + // And its not cancelled if !timeout_entry.borrow().cancelled { + // Then we wake the timeout entry's waker timeout_entry.borrow().waker.wake_by_ref(); } + // We actually remove the timeout entry if its expired timeout_queue.pop(); } else { + // If we can't remove any timeouts - we break break; } } diff --git a/kernel/src/task/udp_echo.rs b/kernel/src/task/udp_echo.rs index 9f2e117..970df03 100644 --- a/kernel/src/task/udp_echo.rs +++ b/kernel/src/task/udp_echo.rs @@ -1,49 +1,58 @@ use crate::{ network::{ - raw_socket::NetworkErrors, + errors::NetworkErrors, socket::{Socket, SocketType}, }, print, println, }; use alloc::string::String; +/// A UDP application that listens on port 5554 and echos any data sent to it +/// Used for testing or learning how to use the socket API pub async fn udp_echo_server() { let socket_or_err = Socket::open(SocketType::UDP, 5554, 15).await; - match socket_or_err { - Ok(mut socket) => { - // Listen for a single connection - socket.listen().await; - loop { - let data_or_err = socket.read(0).await; - if let Ok(mut data) = data_or_err { - let user_message = String::from_utf8(data.clone()); - match user_message { - Ok(message) => { - print!("[USER] {}", message); - if message.to_lowercase().trim() == "quit" || message.to_lowercase().trim() == "exit" { - println!("Closing socket"); - let _ = socket.write(&mut ("Closing socket...".as_bytes().to_vec())).await; - socket.close(); - return; - } - } - Err(err) => println!("[USER-ERR] {:?}", err), - } - let res_or_err = socket.write(&mut data).await; - if let Err(err) = res_or_err { - println!("[ERR] {:?}", err); - break; - } - } else if let Err(err) = data_or_err { - if err == NetworkErrors::Timeout { - println!("[INFO] UDP-Echo ::> Socket had a timeout reading data!!"); - continue; - } - println!("[ERR] {:?}", err); - break; + // Print error in trying to establish a socket + if let Err(err) = socket_or_err { + println!("[ERR] {:?}", err); + return; + } + let mut socket = socket_or_err.unwrap(); + // Listen for a single connection + socket.listen().await; + loop { + // Loop trying to read from the socket + let data_or_err = socket.read(0).await; + if let Ok(mut data) = data_or_err { + // Parse the data we read + let user_message = String::from_utf8(data.clone()); + if let Ok(message) = user_message { + // If we can read the user message, print it + print!("[USER] {}", message); + if message.to_lowercase().trim() == "quit" || message.to_lowercase().trim() == "exit" { + // If exit or quit, we close the socket + println!("Closing socket"); + let _ = socket.write(&mut ("Closing socket...".as_bytes().to_vec())).await; + socket.close().await; + return; } } + // Echo the data + let res_or_err = socket.write(&mut data).await; + if let Err(err) = res_or_err { + // If we got an error -> print, close, exit + println!("[ERR] {:?}", err); + socket.close().await; + break; + } + } else if let Err(err) = data_or_err { + if err == NetworkErrors::Timeout { + // If error is just timeout, we can continue trying to read + continue; + } + // If we got an error -> print, close, exit + println!("[ERR] {:?}", err); + socket.close().await; + break; } - Err(err) => println!("[ERR] {:?}", err), } } From a528e53d199fd64cb55d14893613a2a5e545df9d Mon Sep 17 00:00:00 2001 From: EthanLavi <89608076+EthanLavi@users.noreply.github.com> Date: Sun, 19 Nov 2023 00:31:46 -0500 Subject: [PATCH 19/36] tiny fixes + debugging --- kernel/src/interrupts.rs | 1 + kernel/src/network/processing.rs | 10 ++++++---- kernel/src/network/rtl8139.rs | 3 +++ kernel/src/network/tcp.rs | 2 +- kernel/src/task/tcp_echo.rs | 1 + 5 files changed, 12 insertions(+), 5 deletions(-) diff --git a/kernel/src/interrupts.rs b/kernel/src/interrupts.rs index 7d1340f..127c100 100644 --- a/kernel/src/interrupts.rs +++ b/kernel/src/interrupts.rs @@ -57,6 +57,7 @@ impl InterruptHandler { /// Static function for unblocking an interrupt by irq_num pub fn unblock_irq(irq_num: u8) { + // todo: disable interrupts, maybe this is why we sometimes deadlock let mut locked_pics = PICS.lock(); let data = unsafe { locked_pics.read_masks() }; // set the irq bit to 0 diff --git a/kernel/src/network/processing.rs b/kernel/src/network/processing.rs index 2800863..95b3695 100644 --- a/kernel/src/network/processing.rs +++ b/kernel/src/network/processing.rs @@ -16,7 +16,7 @@ use crate::{ tcp::TCPPacket, tcp_session::{SessionAction, TCPSession}, errors::NetworkErrors, }, - println, + println, print, }; use core::{ @@ -92,6 +92,7 @@ pub async fn init_packet_processing() { let mut raw_packets = PendingProcessingStream::new(); // Get the next piece of data the interrupt handler drew from the card while let Some(pkt_data) = raw_packets.next().await { + print!("."); // Parse it let amount_parsed_and_pkt = full_parse(pkt_data.as_slice()); if amount_parsed_and_pkt.1.get_type() == LayerType::ERR || amount_parsed_and_pkt.1.get_type() == LayerType::ICMP { @@ -102,9 +103,9 @@ pub async fn init_packet_processing() { assert!(amount_parsed_and_pkt.0 == pkt_data.len() || pkt_data.len() < 64); // Get the driver configuration disable_network_interrupts(); - let mut net_dev = NET_INFO.lock(); + let mut rtl_dev_info_guard = NET_INFO.lock(); // Get the device fields - let rtl_dev_info = net_dev.get_mut().unwrap(); + let rtl_dev_info = rtl_dev_info_guard.get_mut().unwrap(); let mut rtl_dev_config = NET_INFO.config.lock(); match amount_parsed_and_pkt.1 { PacketData::ARP(arp) => { @@ -183,6 +184,7 @@ pub async fn init_packet_processing() { if !rtl_dev_config.tcp_sessions.contains_key(&session_key) { if (tcp.get_flags() & TCP_SYN) == 0 { // Don't create new session when there is no request for syncing + drop(rtl_dev_info_guard); enable_network_interrupts(); return; } @@ -238,7 +240,7 @@ pub async fn init_packet_processing() { _ => {} // ignore other packets } // Release the guard and enable interrupts - drop(net_dev); + drop(rtl_dev_info_guard); enable_network_interrupts(); } } diff --git a/kernel/src/network/rtl8139.rs b/kernel/src/network/rtl8139.rs index 37b1230..3e96f62 100644 --- a/kernel/src/network/rtl8139.rs +++ b/kernel/src/network/rtl8139.rs @@ -18,6 +18,7 @@ use crate::interrupts::IDT; use crate::network::constants::{CAPR, CR, CR_BUFE, CR_RE, CR_TE, RX_BUFFER_SIZE, RX_READ_PTR_MASK}; use crate::network::raw_array::WrappingRawArray; +use crate::print; use super::constants::{INTERRUPT_MASK, ROK, RTL_DEV, RTL_VEND, TOK, TRANSMIT_CMD, TRANSMIT_REG}; use super::errors::NetworkErrors; @@ -80,7 +81,9 @@ pub fn enable_network_interrupts() { pub extern "x86-interrupt" fn network_handle(_stack_frame: InterruptStackFrame) { // Try to get the device info + print!("<"); let mut net_dev = unsafe { NET_INFO.lock_no_disable() }; + print!(">"); if net_dev.is_none() { panic!("RTL_INFO is undefined!"); } diff --git a/kernel/src/network/tcp.rs b/kernel/src/network/tcp.rs index c99ecd2..707af71 100644 --- a/kernel/src/network/tcp.rs +++ b/kernel/src/network/tcp.rs @@ -133,7 +133,7 @@ impl Layer for TCPPacket { packet.sliding_window = Bytefield16::read_inc(&bytevec[i..], &mut i); packet.checksum = Bytefield16::read_inc(&bytevec[i..], &mut i); packet.urgent = Bytefield16::read_inc(&bytevec[i..], &mut i); - // read remaining bytes of heaer and place them into the options buffer + // read remaining bytes of header and place them into the options buffer for _ in 0..(packet.get_header_offset() - 20) { packet.options.push(bytevec[i]); i += 1; diff --git a/kernel/src/task/tcp_echo.rs b/kernel/src/task/tcp_echo.rs index d58a5df..f8a533b 100644 --- a/kernel/src/task/tcp_echo.rs +++ b/kernel/src/task/tcp_echo.rs @@ -62,4 +62,5 @@ pub async fn tcp_echo_server() { } } } + println!("[INFO] Finished up all allocated sockets for echoing"); } From 5fe4650082994f7c736de58e06ae1eac4ca555a1 Mon Sep 17 00:00:00 2001 From: EthanLavi <89608076+EthanLavi@users.noreply.github.com> Date: Tue, 21 Nov 2023 13:30:59 -0500 Subject: [PATCH 20/36] Refactor to be all constants --- kernel/src/interrupts.rs | 4 +- kernel/src/network/README.md | 1 - kernel/src/network/TODO.md | 4 +- kernel/src/network/constants.rs | 19 +++++++--- kernel/src/network/processing.rs | 1 - kernel/src/network/rtl8139.rs | 63 +++++++++++++++----------------- 6 files changed, 47 insertions(+), 45 deletions(-) diff --git a/kernel/src/interrupts.rs b/kernel/src/interrupts.rs index 127c100..c2c3f78 100644 --- a/kernel/src/interrupts.rs +++ b/kernel/src/interrupts.rs @@ -83,8 +83,8 @@ impl InterruptHandler { /// Register an interrupt handler /// Used by the RTL card to receive device interrupts pub fn register_irq(&mut self, irq_num: usize, handler: InterruptHandlerFunc) { - println!("Registered Handler @ {}", irq_num + 32); - self.idt[irq_num + 32].set_handler_fn(handler); + println!("Registered Handler @ {}", irq_num + PIC_1_OFFSET as usize); + self.idt[irq_num + PIC_1_OFFSET as usize].set_handler_fn(handler); unsafe { self.idt.load_unsafe() }; println!("Registered IRQ @ {}", irq_num); } diff --git a/kernel/src/network/README.md b/kernel/src/network/README.md index b234e74..8e62928 100644 --- a/kernel/src/network/README.md +++ b/kernel/src/network/README.md @@ -16,7 +16,6 @@ TODO [] Verify other parts of the packet [] Fix synchronization to be much cleaner [] Clean up ugly stuff -[] Refactor to be all constants [] search for todo and fix thoses [] Refactor to include more documentation on the network module diff --git a/kernel/src/network/TODO.md b/kernel/src/network/TODO.md index 57101fc..c896a09 100644 --- a/kernel/src/network/TODO.md +++ b/kernel/src/network/TODO.md @@ -7,8 +7,6 @@ * Fix synchronization to be much cleaner * Fix checksums to be baked-in * Clean up ugly stuff -* DHCP parse additional options -* Refactor to be all constants * Search for todo and fix thoses * Benchmarking -* Rename ip and mac address to a standard \ No newline at end of file +* Rename ip and mac address to a standard diff --git a/kernel/src/network/constants.rs b/kernel/src/network/constants.rs index ce5e374..d148071 100644 --- a/kernel/src/network/constants.rs +++ b/kernel/src/network/constants.rs @@ -20,13 +20,22 @@ pub const TRANSMIT_REG: [u32; 4] = [0x20, 0x24, 0x28, 0x2C]; // 4 transmit regis pub const TRANSMIT_CMD: [u32; 4] = [0x10, 0x14, 0x18, 0x1C]; // 4 cmd registers pub const INTERRUPT_MASK: u16 = 0x01 | 0x04 | 0x10 | 0x08 | 0x02; // interrupt mask pub const RX_BUFFER_SIZE: u16 = 8192; // how big the buffer is -pub const CR_RST: u16 = 0x10; // Reset, set to 1 to invoke S/W reset, held to 1 while resetting -pub const CR_RE: u8 = 0x08; // Reciever Enable, enables receiving -pub const CR_TE: u8 = 0x04; // Transmitter Enable, enables transmitting -pub const CR_BUFE: u8 = 0x01; // Rx buffer is empty -pub const CR: u32 = 0x37; // command register (1byte) +pub const CONFIG_1_REG: u16 = 0x52; // has LWAKE + LWPTN, will allow power on device if active high (0x00) +pub const CMD_REG_RST: u8 = 0x10; // Reset, set to 1 to invoke S/W reset, held to 1 while resetting +pub const CMD_REG_RE: u8 = 0x08; // Reciever Enable, enables receiving +pub const CMD_REG_TE: u8 = 0x04; // Transmitter Enable, enables transmitting +pub const CMD_REG_BUFE: u8 = 0x01; // Rx buffer is empty +pub const CMD_REG: u32 = 0x37; // command register (1byte) pub const CAPR: u32 = 0x38; // current address of packet read (2byte, C mode, initial value 0xFFF0) pub const RX_READ_PTR_MASK: u16 = !0x3; // Used to align to 8 bytes in RTL8139 driver +pub const IMR_REG: u16 = 0x3C; // Interrupt mask register +pub const ISR_REG: u16 = 0x3E; // Interrupt service register +pub const RX_START_REG: u16 = 0x30; // Receive buffer start register +pub const RX_BUF_REG: u16 = 0x44; // receive buffer config register +pub const RX_BROADCAST: u32 = 0x08; // Accept broadcast packets sent to mac ff:ff:ff:ff:ff:ff +pub const RX_MULTICAST: u32 = 0x04; // Accept multicast packets +pub const RX_PHYSICAL_MATCH: u32 = 0x02; // Accept physical matches +pub const RX_PROMISCOUS: u32 = 0x01; // Accept all packets // TCP Constants pub const TCP_FIN: u8 = 0x1; // TCP FIN flag (gracefully closing connection) diff --git a/kernel/src/network/processing.rs b/kernel/src/network/processing.rs index 95b3695..9d2a05f 100644 --- a/kernel/src/network/processing.rs +++ b/kernel/src/network/processing.rs @@ -92,7 +92,6 @@ pub async fn init_packet_processing() { let mut raw_packets = PendingProcessingStream::new(); // Get the next piece of data the interrupt handler drew from the card while let Some(pkt_data) = raw_packets.next().await { - print!("."); // Parse it let amount_parsed_and_pkt = full_parse(pkt_data.as_slice()); if amount_parsed_and_pkt.1.get_type() == LayerType::ERR || amount_parsed_and_pkt.1.get_type() == LayerType::ICMP { diff --git a/kernel/src/network/rtl8139.rs b/kernel/src/network/rtl8139.rs index 3e96f62..c63b1fc 100644 --- a/kernel/src/network/rtl8139.rs +++ b/kernel/src/network/rtl8139.rs @@ -14,13 +14,16 @@ use x86_64::{ PhysAddr, VirtAddr, }; -use crate::interrupts::IDT; +use crate::interrupts::{IDT, PIC_1_OFFSET}; -use crate::network::constants::{CAPR, CR, CR_BUFE, CR_RE, CR_TE, RX_BUFFER_SIZE, RX_READ_PTR_MASK}; +use crate::network::constants::{ + CAPR, CMD_REG, CMD_REG_BUFE, CMD_REG_RE, CMD_REG_RST, CMD_REG_TE, RX_BROADCAST, RX_BUFFER_SIZE, RX_BUF_REG, RX_READ_PTR_MASK, + RX_START_REG, RX_MULTICAST, RX_PHYSICAL_MATCH, RX_PROMISCOUS, CONFIG_1_REG, +}; use crate::network::raw_array::WrappingRawArray; use crate::print; -use super::constants::{INTERRUPT_MASK, ROK, RTL_DEV, RTL_VEND, TOK, TRANSMIT_CMD, TRANSMIT_REG}; +use super::constants::{IMR_REG, INTERRUPT_MASK, ISR_REG, ROK, RTL_DEV, RTL_VEND, TOK, TRANSMIT_CMD, TRANSMIT_REG}; use super::errors::NetworkErrors; use super::processing::add_pkt_data; use super::tcp_session::TCPSession; @@ -63,7 +66,7 @@ static mut INTERRUPTS_ARE_ENABLED: spin::Mutex = spin::Mutex:: pub fn disable_network_interrupts() { let mut data = unsafe { INTERRUPTS_ARE_ENABLED.lock() }; if data.get() == 0 { - let mut port_imr = Port::::new((unsafe { IO_BASE } + 0x3C) as u16); + let mut port_imr = Port::::new((unsafe { IO_BASE } as u16) + IMR_REG); unsafe { port_imr.write(0x0) }; } data.inc(); @@ -74,16 +77,14 @@ pub fn enable_network_interrupts() { let mut data = unsafe { INTERRUPTS_ARE_ENABLED.lock() }; data.dec(); if data.get() == 0 { - let mut port_imr = Port::::new((unsafe { IO_BASE } + 0x3C) as u16); + let mut port_imr = Port::::new((unsafe { IO_BASE } as u16) + IMR_REG); unsafe { port_imr.write(INTERRUPT_MASK) }; } } pub extern "x86-interrupt" fn network_handle(_stack_frame: InterruptStackFrame) { // Try to get the device info - print!("<"); let mut net_dev = unsafe { NET_INFO.lock_no_disable() }; - print!(">"); if net_dev.is_none() { panic!("RTL_INFO is undefined!"); } @@ -94,17 +95,17 @@ pub extern "x86-interrupt" fn network_handle(_stack_frame: InterruptStackFrame) if io_base.is_none() || irq.is_none() { println!("[ERR] Handling packet - missing data"); unsafe { - PICS.lock().notify_end_of_interrupt(irq.unwrap() + 32); + PICS.lock().notify_end_of_interrupt(irq.unwrap() + PIC_1_OFFSET); } return; } // stop interrupts to the device - let mut port_imr = Port::::new((io_base.unwrap() + 0x3C) as u16); + let mut port_imr = Port::::new(io_base.unwrap() as u16 + IMR_REG); unsafe { port_imr.write(0x0) }; // Read the ISR register - let mut port_isr = Port::::new((io_base.unwrap() + 0x3E) as u16); + let mut port_isr = Port::::new((io_base.unwrap() as u16) + ISR_REG); let status = unsafe { port_isr.read() }; // Reset the ISR register unsafe { port_isr.write(0x05) }; @@ -121,7 +122,7 @@ pub extern "x86-interrupt" fn network_handle(_stack_frame: InterruptStackFrame) // Notify end of interrupt unsafe { - PICS.lock().notify_end_of_interrupt(irq.unwrap() + 32); + PICS.lock().notify_end_of_interrupt(irq.unwrap() + PIC_1_OFFSET); } } @@ -131,10 +132,10 @@ fn recv_packet(rtl_dev_info: &RTL8139) { panic!("RTL8139 is not initialized properly"); } // Make sure buffer isn't empty - let cmd_reg = (rtl_dev_info.config.io_base.unwrap() + CR) as u16; + let cmd_reg = (rtl_dev_info.config.io_base.unwrap() + CMD_REG) as u16; let mut cmd_port = Port::::new(cmd_reg); - // while unsafe { cmd_port.read() } & CR_BUFE == 0x0 { - if unsafe { cmd_port.read() } & CR_BUFE == 0x0 { + // while unsafe { cmd_port.read() } & CMD_REG_BUFE == 0x0 { + if unsafe { cmd_port.read() } & CMD_REG_BUFE == 0x0 { // Receive a packet by reading the buffer // ? Reading the buffer is naturally unsafe? Is there a better way? let virtual_buffer_recv: VirtAddr = @@ -322,44 +323,40 @@ impl RTL8139 { self.mac_address = Some(mac_addr); println!("[INFO] MAC address is {:#10x}", mac_addr); - // turn on the card - let addr = self.config.io_base.unwrap() + 0x52; - let mut port_config_1 = Port::::new(addr as u16); + // turn on the card by setting config 1 to 0x00 + let config_1_reg = self.config.io_base.unwrap() as u16 + CONFIG_1_REG; + let mut port_config_1 = Port::::new(config_1_reg); unsafe { port_config_1.write(0x00) }; // Performing software reset - let cmd_reg = self.config.io_base.unwrap() + CR; + let cmd_reg = self.config.io_base.unwrap() + CMD_REG; let mut port_rst = Port::::new(cmd_reg as u16); unsafe { - port_rst.write(0x10); - while port_rst.read() & 0x10 != 0 {} // spin until we observe reset is over + port_rst.write(CMD_REG_RST); + while port_rst.read() & CMD_REG_RST != 0 {} // spin until we observe reset is over }; println!("[INFO] RTL8139 has been reset!"); // Enable receiver and transmitter let mut port_recv_transmit = Port::::new(cmd_reg as u16); - unsafe { port_recv_transmit.write(CR_TE | CR_RE) }; + unsafe { port_recv_transmit.write(CMD_REG_TE | CMD_REG_RE) }; // Configuring receive buffer - let rcr_reg = self.config.io_base.unwrap() + 0x44; - let mut rcr = Port::::new(rcr_reg as u16); + let rcr_reg = self.config.io_base.unwrap() as u16 + RX_BUF_REG; + let mut rcr = Port::::new(rcr_reg); unsafe { - let broadcast = 0x08; // Accept broadcast packets sent to mac ff:ff:ff:ff:ff:ff - let multicast = 0x04; // Accept multicast packets - let physical_match = 0x02; // Accept physical matches - let promiscous = 0x01; // Accept all packets - // (1 << 7) is the WRAP bit, 0xf is broadcast, multicast, physical match, accept all packets - rcr.write(physical_match | multicast | broadcast | promiscous); + // No wrap, ring buffer, accept all packets + rcr.write(RX_PHYSICAL_MATCH | RX_MULTICAST | RX_BROADCAST | RX_PROMISCOUS); }; // Init receive buffer - let rcv_buf_reg = self.config.io_base.unwrap() + 0x30; - let mut rcv_buffer = Port::::new(rcv_buf_reg as u16); + let rcv_buf_reg = self.config.io_base.unwrap() as u16 + RX_START_REG; + let mut rcv_buffer = Port::::new(rcv_buf_reg); unsafe { rcv_buffer.write(self.recv_buffer.unwrap().as_u64() as u32) }; // Set IMR + ISR - let imr_reg = self.config.io_base.unwrap() + 0x3C; - let mut imr = Port::::new(imr_reg as u16); + let imr_reg = self.config.io_base.unwrap() as u16 + IMR_REG; + let mut imr = Port::::new(imr_reg); unsafe { imr.write(INTERRUPT_MASK) }; // Enable interrupts From 4d139840435b62b4929e2998f728c9fd327df39e Mon Sep 17 00:00:00 2001 From: EthanLavi <89608076+EthanLavi@users.noreply.github.com> Date: Wed, 22 Nov 2023 10:40:31 -0500 Subject: [PATCH 21/36] Verifying checksums and renaming --- kernel/src/network/README.md | 1 - kernel/src/network/TODO.md | 5 +- kernel/src/network/arp.rs | 56 +++--- kernel/src/network/bytefield.rs | 22 ++- kernel/src/network/dhcp.rs | 48 +++-- kernel/src/network/ethernet.rs | 16 +- kernel/src/network/init.rs | 9 +- kernel/src/network/ip.rs | 109 +++++------ kernel/src/network/layer.rs | 91 ++++++---- kernel/src/network/network_query.rs | 4 +- kernel/src/network/processing.rs | 270 +++++++++++++++------------- kernel/src/network/rtl8139.rs | 1 - kernel/src/network/socket.rs | 29 ++- kernel/src/network/tcp.rs | 102 +++++++---- kernel/src/network/tcp_session.rs | 26 +-- kernel/src/network/udp.rs | 93 ++++++---- 16 files changed, 510 insertions(+), 372 deletions(-) diff --git a/kernel/src/network/README.md b/kernel/src/network/README.md index 8e62928..8d8eca8 100644 --- a/kernel/src/network/README.md +++ b/kernel/src/network/README.md @@ -12,7 +12,6 @@ TODO [x] Async IO [x] Timeouts [] Refactor so that all of networking is tested -[] Refactor to verify checksums [] Verify other parts of the packet [] Fix synchronization to be much cleaner [] Clean up ugly stuff diff --git a/kernel/src/network/TODO.md b/kernel/src/network/TODO.md index c896a09..5660b88 100644 --- a/kernel/src/network/TODO.md +++ b/kernel/src/network/TODO.md @@ -2,11 +2,8 @@ * Refactor so that all of networking is tested * Refactor to include more documentation on the network module -* Refactor to verify checksums -* Verify other parts of the packet +* Verify other parts of the packet (in rtl) * Fix synchronization to be much cleaner -* Fix checksums to be baked-in * Clean up ugly stuff * Search for todo and fix thoses * Benchmarking -* Rename ip and mac address to a standard diff --git a/kernel/src/network/arp.rs b/kernel/src/network/arp.rs index 05efd9f..ed17a0b 100644 --- a/kernel/src/network/arp.rs +++ b/kernel/src/network/arp.rs @@ -22,13 +22,13 @@ pub struct ArpPacket { /// 1 for request, 2 for reply operation: Bytefield16, /// The sender's mac address - pub sender_mac: Bytefield48, + pub src_mac: Bytefield48, /// The sender's IP address - pub sender_ip: Bytefield32, + pub src_ip: Bytefield32, /// The mac address of the receiver (0 if a request, this is the question field) - pub recp_mac: Bytefield48, + pub dest_mac: Bytefield48, /// The recepient IP, this is also part of the question if a request - pub recp_ip: Bytefield32, + pub dest_ip: Bytefield32, } impl ArpPacket { @@ -41,22 +41,22 @@ impl ArpPacket { hardware_address_length: 0, protocol_address_length: 0, operation: Bytefield16::new(0), - sender_mac: Bytefield48::new(0), - sender_ip: Bytefield32::new(0), - recp_mac: Bytefield48::new(0), - recp_ip: Bytefield32::new(0), + src_mac: Bytefield48::new(0), + src_ip: Bytefield32::new(0), + dest_mac: Bytefield48::new(0), + dest_ip: Bytefield32::new(0), } } /// Generate a ARP packet with /// - eth_layer: is the ethernet frame associated with the packet - /// - source_ip: is the machine's IP address - /// - recp_ip: is the destination's IP address + /// - src_ip: is the machine's IP address + /// - dest_ip: is the destination's IP address /// - is_req: is true if the packet is a request and false if the packet is a response - pub fn gen(eth_layer: EthernetPacket, source_ip: u32, recp_ip: u32, is_req: bool) -> Self { - // Extract the recp_mac and sender_mac from the ethernet layer - let recp_mac = eth_layer.dest_mac; - let sender_mac = eth_layer.src_mac; + pub fn gen(eth_layer: EthernetPacket, src_ip: u32, dest_ip: u32, is_req: bool) -> Self { + // Extract the dest_mac and src_mac from the ethernet layer + let dest_mac = eth_layer.dest_mac; + let src_mac = eth_layer.src_mac; // Assert the eth layer matches us assert!(eth_layer.packet_type == EthType::Arp); // Construct the arp packet @@ -67,10 +67,10 @@ impl ArpPacket { hardware_address_length: 6, // ethernet is the value 6 protocol_address_length: 4, // ethernet is the value 4 operation: Bytefield16::new(if is_req { 1 } else { 2 }), // Request is 1, Reply is 2 - sender_mac, - sender_ip: Bytefield32::new(source_ip), - recp_mac, - recp_ip: Bytefield32::new(if is_req { 0 } else { recp_ip }), + src_mac, + src_ip: Bytefield32::new(src_ip), + dest_mac, + dest_ip: Bytefield32::new(if is_req { 0 } else { dest_ip }), } } } @@ -94,14 +94,14 @@ impl Layer for ArpPacket { packet.hardware_address_length = Bytefield8::read_inc(&bytevec[i..], &mut i).val(); packet.protocol_address_length = Bytefield8::read_inc(&bytevec[i..], &mut i).val(); packet.operation = Bytefield16::read_inc(&bytevec[i..], &mut i); - packet.sender_mac = Bytefield48::read_inc(&bytevec[i..], &mut i); - packet.sender_ip = Bytefield32::read_inc(&bytevec[i..], &mut i); - packet.recp_mac = Bytefield48::read_inc(&bytevec[i..], &mut i); - packet.recp_ip = Bytefield32::read_inc(&bytevec[i..], &mut i); + packet.src_mac = Bytefield48::read_inc(&bytevec[i..], &mut i); + packet.src_ip = Bytefield32::read_inc(&bytevec[i..], &mut i); + packet.dest_mac = Bytefield48::read_inc(&bytevec[i..], &mut i); + packet.dest_ip = Bytefield32::read_inc(&bytevec[i..], &mut i); assert!(i == 28); // Arp packet should be 28 bytes - + // Return the packet, the amount of data consumed, and the next layer type (end of parse) - (packet, i, LayerType::END) + (packet, i, LayerType::End) } /// Serialize the packet into a vector of bytes, ready to send over the network @@ -114,10 +114,10 @@ impl Layer for ArpPacket { res.push(self.hardware_address_length); res.push(self.protocol_address_length); res.extend(self.operation.data); - res.extend(self.sender_mac.data); - res.extend(self.sender_ip.data); - res.extend(self.recp_mac.data); - res.extend(self.recp_ip.data); + res.extend(self.src_mac.data); + res.extend(self.src_ip.data); + res.extend(self.dest_mac.data); + res.extend(self.dest_ip.data); res } diff --git a/kernel/src/network/bytefield.rs b/kernel/src/network/bytefield.rs index 30ab0e1..2f53cb9 100644 --- a/kernel/src/network/bytefield.rs +++ b/kernel/src/network/bytefield.rs @@ -1,4 +1,4 @@ -use core::ops::{Index, IndexMut}; +use core::{ops::{Index, IndexMut}, fmt}; // N.B.: Bytefields will swap the endianness of the values when created // --> therefore serializing will create network byte order (big endian) Bytefields @@ -8,7 +8,7 @@ use core::ops::{Index, IndexMut}; /// A structure to represent data as an array of bytes /// Will store in both host and network byte-order. It is dependent on it's construction -#[derive(Debug, Clone, Copy)] +#[derive(Clone, Copy)] pub struct Bytefield { pub data: [u8; N], } @@ -53,6 +53,24 @@ impl IndexMut for Bytefield { } } +/// Wrote custom debug print +impl fmt::Debug for Bytefield { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if write!(f, "Bytefield<{}>(data: [", N).is_err() { + return fmt::Result::Err(fmt::Error); + } + for data in self.data.iter() { + if write!(f, "{:#4x}, ", data).is_err() { + return fmt::Result::Err(fmt::Error); + } + } + if write!(f, "])").is_err() { + return fmt::Result::Err(fmt::Error); + } + fmt::Result::Ok(()) + } +} + macro_rules! bytefield_int { ($t:ident, $int:ident, $size:literal) => { pub type $t = Bytefield<$size>; diff --git a/kernel/src/network/dhcp.rs b/kernel/src/network/dhcp.rs index 69b924c..18549e7 100644 --- a/kernel/src/network/dhcp.rs +++ b/kernel/src/network/dhcp.rs @@ -2,8 +2,9 @@ use alloc::vec; use super::{ bytefield::{Bytefield128, Bytefield16, Bytefield32, Bytefield48, Bytefield8}, + ethernet::EthernetPacket, ip::IPPacket, - layer::{HasChecksum, Layer, LayerType}, + layer::{calculate_checksum_inner, HasChecksum, Layer, LayerType}, udp::UDPPacket, }; @@ -107,7 +108,7 @@ impl DHCPPacket { client_hardware_address[i] = mac[i]; } // Generate the DHCP packet - let mut dhcp = DHCPPacket { + DHCPPacket { udp: udp_layer, op_code: 1, // 1 for is request hardware_type: 1, // ethernet is 1 @@ -125,14 +126,7 @@ impl DHCPPacket { file: [0; 128], // 128 bytes options: [0; 64], // todo: 64 bytes (can be more) // 300 bytes total - }; - // Calculate the checksums - let data = dhcp.serialize(); - let start_udp = data.len() - (DHCPPacket::packet_size() as usize + UDPPacket::packet_size() as usize); - let start_ip = start_udp - (IPPacket::packet_size() as usize); - dhcp.udp.ip.calculate_checksum(&data[start_ip..start_udp]); - dhcp.udp.calculate_checksum(&data[start_udp..]); - dhcp + } } } @@ -175,7 +169,7 @@ impl Layer for DHCPPacket { i += left_to_parse as usize; assert!(i >= 300); // 300 bytes // Return the packet, the amount of data consumed, and the next layer type (end of parse) - (packet, i, LayerType::END) + (packet, i, LayerType::End) } /// Serialize the packet into a vector of bytes, ready to send over the network @@ -208,3 +202,35 @@ impl Layer for DHCPPacket { 300 } } + +impl HasChecksum for DHCPPacket { + fn calculate_checksum(&mut self) { + // Starting vars + let mut sum: u32 = 0; + + // First we do the IP as a pseduo header + let ip = &self.udp.ip; + sum += (ip.src_ip.data[0] as u32) | (ip.src_ip.data[1] as u32) << 8; + sum += (ip.src_ip.data[2] as u32) | (ip.src_ip.data[3] as u32) << 8; + sum += (ip.dest_ip.data[0] as u32) | (ip.dest_ip.data[1] as u32) << 8; + sum += (ip.dest_ip.data[2] as u32) | (ip.dest_ip.data[3] as u32) << 8; + + // Sum protocol and length + let protocol = Bytefield16::new(ip.protocol as u16); + sum += (protocol.data[0] as u32) | (protocol.data[1] as u32) << 8; + sum += (self.udp.length.data[0] as u32) | (self.udp.length.data[1] as u32) << 8; + + // Sum the body + self.udp.checksum = Bytefield16::new(0); + let data = self.serialize(); + let start_udp = IPPacket::packet_size() + EthernetPacket::packet_size(); + let res = calculate_checksum_inner(&data[start_udp as usize..], sum); + + // Save checksum + self.udp.checksum = Bytefield16::new(res); + } + + fn verify_checksum(&mut self) -> bool { + self.udp.ip.verify_checksum() + } +} diff --git a/kernel/src/network/ethernet.rs b/kernel/src/network/ethernet.rs index b398877..f0de3f5 100644 --- a/kernel/src/network/ethernet.rs +++ b/kernel/src/network/ethernet.rs @@ -51,13 +51,13 @@ impl EthernetPacket { } /// Generate a Ethernet packet with - /// - destination_mac: the destination mac address - /// - source_mac: the source mac address + /// - dest_mac: the destination mac address + /// - src_mac: the source mac address /// - packet_type: the class of packet to send - pub fn gen(destination_mac: u64, source_mac: u64, packet_type: EthType) -> Self { + pub fn gen(dest_mac: u64, src_mac: u64, packet_type: EthType) -> Self { EthernetPacket { - dest_mac: Bytefield48::new(destination_mac), - src_mac: Bytefield48::new(source_mac), + dest_mac: Bytefield48::new(dest_mac), + src_mac: Bytefield48::new(src_mac), packet_type, } } @@ -83,10 +83,10 @@ impl Layer for EthernetPacket { assert!(i == 14); // assert 14 bytes // Match on the packet type to make sure we don't have an error let layer_type = match &packet.packet_type { - EthType::Arp => LayerType::ARP, + EthType::Arp => LayerType::Arp, EthType::IPv4 => LayerType::IP, - EthType::RoCE => LayerType::ERR, - EthType::Unknown => LayerType::ERR, + EthType::RoCE => LayerType::Err, + EthType::Unknown => LayerType::Err, }; // Return the packet, the amount of data consumed, and the next layer type (based on EthType) (packet, i, layer_type) diff --git a/kernel/src/network/init.rs b/kernel/src/network/init.rs index b6d641d..6ad1bc3 100644 --- a/kernel/src/network/init.rs +++ b/kernel/src/network/init.rs @@ -1,6 +1,7 @@ use futures_util::StreamExt; use super::constants::{BROADCAST_ADDR, BROADCAST_MAC, DHCP_CLIENT_PORT}; +use super::layer::HasChecksum; use super::processing; use crate::network::dhcp::DHCPPacket; use crate::network::ethernet::{EthType, EthernetPacket}; @@ -31,9 +32,11 @@ pub async fn init_dhcp(wait_timeout: u8) -> bool { // create dhcp initial request let eth = EthernetPacket::gen(BROADCAST_MAC, rtl_dev_info.mac_address.unwrap(), EthType::IPv4); let ip_size = DHCPPacket::packet_size() + UDPPacket::packet_size(); - let ip = IPPacket::gen(eth, ip_size, Protocol::UDP, 0x0, BROADCAST_ADDR); + let ip = IPPacket::gen(eth, ip_size, Protocol::Udp, 0x0, BROADCAST_ADDR); let udp = UDPPacket::gen(ip, DHCP_CLIENT_PORT as u16, DHCP_SERVER_PORT as u16, DHCPPacket::packet_size()); - let dhcp = DHCPPacket::gen(udp, None, rtl_dev_info.mac_address.unwrap()); + let mut dhcp = DHCPPacket::gen(udp, None, rtl_dev_info.mac_address.unwrap()); + dhcp.udp.ip.calculate_checksum(); + dhcp.calculate_checksum(); let packet_data = dhcp.serialize(); // Send the first BOOTP packet @@ -78,7 +81,7 @@ pub async fn init_dhcp(wait_timeout: u8) -> bool { disable_network_interrupts(); let mut rtl_dev_guard = NET_INFO.lock(); let rtl_dev_info = rtl_dev_guard.get_mut().unwrap(); - if pkt_data.get_type() == LayerType::DHCP { + if pkt_data.get_type() == LayerType::Dhcp { // With the driver object, record the information from the DHCP server let dhcp_res = pkt_data.unwrap_dhcp(); rtl_dev_info.my_ip_address = Some(dhcp_res.my_ip.val()); diff --git a/kernel/src/network/ip.rs b/kernel/src/network/ip.rs index d63b01c..c75fe4a 100644 --- a/kernel/src/network/ip.rs +++ b/kernel/src/network/ip.rs @@ -4,7 +4,7 @@ use alloc::vec::Vec; use super::{ bytefield::{Bytefield16, Bytefield32, Bytefield8}, ethernet::EthernetPacket, - layer::{HasChecksum, Layer, LayerType}, + layer::{calculate_checksum_inner, HasChecksum, Layer, LayerType}, }; /// Protocol for IP @@ -12,9 +12,9 @@ use super::{ #[derive(Debug, Clone, Copy)] #[repr(u8)] pub enum Protocol { - ICMP = 1, - TCP = 6, - UDP = 17, + Icmp = 1, + Tcp = 6, + Udp = 17, Unsupported = 255, } @@ -22,9 +22,9 @@ impl Protocol { /// Parse a u8 into a Protocol enum pub fn from(data: u8) -> Self { match data { - 1 => Self::ICMP, - 6 => Self::TCP, - 17 => Self::UDP, + 1 => Self::Icmp, + 6 => Self::Tcp, + 17 => Self::Udp, _ => Self::Unsupported, } } @@ -56,25 +56,25 @@ pub struct IPPacket { /// The parent packet pub eth: EthernetPacket, /// IP version (hardcoded) - version_hlen: u8, + pub version_hlen: u8, /// Can increase urgency. Unused - type_of_service: u8, + pub type_of_service: u8, /// The entire size of IP packet + data size pub total_length: Bytefield16, /// A unique ID to help de-fragmentation of the packet - identification: Bytefield16, + pub identification: Bytefield16, /// Flags to prevent fragmentation (we are fine with it) - flags_fragment_offset: Bytefield16, + pub flags_fragment_offset: Bytefield16, /// How many router hops before we drop the packet - ttl: u8, + pub ttl: u8, /// The protocol used to identify the data pub protocol: Protocol, /// The checksum to verify contents of the packet pub checksum: Bytefield16, /// The sender's IP address - pub source_ip: Bytefield32, + pub src_ip: Bytefield32, /// The recepient's IP address - pub destination_ip: Bytefield32, + pub dest_ip: Bytefield32, } impl IPPacket { @@ -90,8 +90,8 @@ impl IPPacket { ttl: 0, protocol: Protocol::Unsupported, checksum: Bytefield16::new(0), - source_ip: Bytefield32::new(0), - destination_ip: Bytefield32::new(0), + src_ip: Bytefield32::new(0), + dest_ip: Bytefield32::new(0), } } @@ -100,8 +100,8 @@ impl IPPacket { /// - data_length: is the data's size /// - protocol: the protocol of the packet (TCP/UDP) /// - src_ip: the sender's IP address - /// - dst_ip: the destination's IP address - pub fn gen(eth_layer: EthernetPacket, data_length: u16, protocol: Protocol, src_ip: u32, dst_ip: u32) -> Self { + /// - dest_ip: the destination's IP address + pub fn gen(eth_layer: EthernetPacket, data_length: u16, protocol: Protocol, src_ip: u32, dest_ip: u32) -> Self { // Generate a unique ID for the packet let identification = unsafe { let mut id_gen = ID_GEN.lock(); @@ -120,8 +120,8 @@ impl IPPacket { ttl: 120, // 120 - our packet needs to make it there protocol, checksum: Bytefield16::new(0), - source_ip: Bytefield32::new(src_ip), - destination_ip: Bytefield32::new(dst_ip), + src_ip: Bytefield32::new(src_ip), + dest_ip: Bytefield32::new(dest_ip), } } } @@ -137,7 +137,8 @@ impl Layer for IPPacket { where Self: Sized, { - let mut packet = IPPacket::new(); // create an empty packet + // create an empty packet + let mut packet = IPPacket::new(); // Save ethernet packet and read 20 bytes let mut i = 0; packet.eth = eth_layer; @@ -150,16 +151,16 @@ impl Layer for IPPacket { let protocol = Bytefield8::read_inc(&bytevec[i..], &mut i); packet.protocol = Protocol::from(protocol.val()); packet.checksum = Bytefield16::read_inc(&bytevec[i..], &mut i); - packet.source_ip = Bytefield32::read_inc(&bytevec[i..], &mut i); - packet.destination_ip = Bytefield32::read_inc(&bytevec[i..], &mut i); + packet.src_ip = Bytefield32::read_inc(&bytevec[i..], &mut i); + packet.dest_ip = Bytefield32::read_inc(&bytevec[i..], &mut i); // assert 20 bytes assert!(i == 20); // Match the protocol to determine next layer let layer_type = match packet.protocol { - Protocol::ICMP => LayerType::ICMP, // unsupported right now - Protocol::TCP => LayerType::TCP, - Protocol::UDP => LayerType::UDP, - Protocol::Unsupported => LayerType::ERR, + Protocol::Icmp => LayerType::Icmp, // unsupported right now + Protocol::Tcp => LayerType::Tcp, + Protocol::Udp => LayerType::Udp, + Protocol::Unsupported => LayerType::Err, }; // Return the packet, the amount of data consumed, and the next layer type (end of parse) @@ -179,8 +180,8 @@ impl Layer for IPPacket { res.push(self.ttl); res.push(self.protocol as u8); res.extend(self.checksum.data); - res.extend(self.source_ip.data); - res.extend(self.destination_ip.data); + res.extend(self.src_ip.data); + res.extend(self.dest_ip.data); assert!(res.len() == (20 + self.eth.serialize().len())); res } @@ -194,35 +195,35 @@ impl Layer for IPPacket { impl HasChecksum for IPPacket { /// Calculate a checksum on the data and the packet /// - will self mutate - fn calculate_checksum(&mut self, data: &[u8]) { - // Starting vars - let mut sum: u32 = 0; - + fn calculate_checksum(&mut self) { // Sum the body self.checksum = Bytefield16::new(0); - let mut ptr = 0; - let mut ip_len = data.len(); - while ip_len > 1 { - sum += (data[ptr] as u32) | ((data[ptr + 1] as u32) << 8); - ip_len -= 2; - ptr += 2; - } - - if data.len() % 2 == 1 { - // Add the padding if the packet length is odd - sum += (data[ptr] as u32) << 8; - } + let data = self.serialize(); + let start_ip = data.len() - IPPacket::packet_size() as usize; + let res = calculate_checksum_inner(&data[start_ip..], 0); - // Add the carries - while sum > 0xFFFF { - sum = (sum & 0xFFFF) + (sum >> 16); - } + // Save checksum + self.checksum = Bytefield16::new(res); + } - // One's complement and swap the bytes because we did our sum in big endian - // (and the bytefield will try to convert to big endian) - let res = u16::swap_bytes(!sum as u16); + /// Check if the checksum is valid in the packet + fn verify_checksum(&mut self) -> bool { + let mut ip: IPPacket = IPPacket { + eth: EthernetPacket::new(), + version_hlen: self.version_hlen, + type_of_service: self.type_of_service, + total_length: self.total_length.swapped_endianness(), + identification: self.identification.swapped_endianness(), + flags_fragment_offset: self.flags_fragment_offset.swapped_endianness(), + ttl: self.ttl, + protocol: self.protocol, + checksum: Bytefield16::new(0), + src_ip: self.src_ip.swapped_endianness(), + dest_ip: self.dest_ip.swapped_endianness(), + }; + ip.calculate_checksum(); - // Return the one's complement of sum - self.checksum = Bytefield16::new(res); + // Return if the checksum is a match + ip.checksum.val() == self.checksum.swapped_endianness().val() } } diff --git a/kernel/src/network/layer.rs b/kernel/src/network/layer.rs index f9995b7..b064d28 100644 --- a/kernel/src/network/layer.rs +++ b/kernel/src/network/layer.rs @@ -35,7 +35,7 @@ impl EmptyLayer { impl Layer for EmptyLayer { type Input = EmptyLayer; - + /// Shouldn't be used fn parse(_upper: EmptyLayer, _bytevec: &[u8]) -> (Self, usize, LayerType) where @@ -57,30 +57,61 @@ impl Layer for EmptyLayer { pub trait HasChecksum { /// Calculate the checksum and self mutate - fn calculate_checksum(&mut self, data: &[u8]); + fn calculate_checksum(&mut self); + + fn verify_checksum(&mut self) -> bool; +} + +pub fn calculate_checksum_inner(body: &[u8], prev_sum: u32) -> u16 { + let mut body_len = body.len(); + let mut sum = prev_sum; + + // Sum the body + let mut ptr = 0; + while body_len > 1 { + sum += (body[ptr] as u32) | ((body[ptr + 1] as u32) << 8); + body_len -= 2; + ptr += 2; + } + + if body_len % 2 == 1 { + // Add the padding if the packet length is odd + sum += body[ptr] as u32; + } + + // Add the carries + while sum > 0xFFFF { + sum = (sum & 0xFFFF) + (sum >> 16); + } + + // One's complement and swap the bytes because we did our sum in big endian + // (and the bytefield will try to convert to big endian) + let res = u16::swap_bytes(!sum as u16); + // Return the new sum ready to be saved into the packet + res } /// A layer type to indicate what the next data holds #[derive(Debug, PartialEq, Eq)] pub enum LayerType { /// EthernetPacket - ETH, + Eth, /// IPPacket IP, /// ARPPacket - ARP, + Arp, /// UDPPacket - UDP, + Udp, /// ICMPPacket - ICMP, + Icmp, /// DHCPPacket - DHCP, + Dhcp, /// TCPPacket - TCP, + Tcp, /// An error occured - ERR, + Err, /// No more data (but not error) - END, + End, } // todo: reduce size of enum @@ -88,7 +119,7 @@ pub enum LayerType { /// Is both a type (what is the kind) and a packet (something that implements Layer) #[derive(Debug)] pub enum PacketData { - ETH(EthernetPacket), + Eth(EthernetPacket), IP(IPPacket), ARP(ArpPacket), UDP(UDPPacket), @@ -103,7 +134,7 @@ impl PacketData { /// Forcefully unwrap the packet as EthernetPacket pub fn unwrap_eth(self) -> EthernetPacket { match self { - PacketData::ETH(val) => val, + PacketData::Eth(val) => val, _ => unreachable!("Mismatched type. Couldn't unwrap"), } } @@ -152,15 +183,15 @@ impl PacketData { /// Get the type of the PacketData by decomposing into LayerType pub fn get_type(&self) -> LayerType { match self { - PacketData::ETH(_) => LayerType::ETH, + PacketData::Eth(_) => LayerType::Eth, PacketData::IP(_) => LayerType::IP, - PacketData::ARP(_) => LayerType::ARP, - PacketData::UDP(_) => LayerType::UDP, - PacketData::ICMP(_) => LayerType::ICMP, - PacketData::DHCP(_) => LayerType::DHCP, - PacketData::TCP(_) => LayerType::TCP, - PacketData::ERR(_) => LayerType::ERR, - PacketData::UNDEF(_) => LayerType::END, + PacketData::ARP(_) => LayerType::Arp, + PacketData::UDP(_) => LayerType::Udp, + PacketData::ICMP(_) => LayerType::Icmp, + PacketData::DHCP(_) => LayerType::Dhcp, + PacketData::TCP(_) => LayerType::Tcp, + PacketData::ERR(_) => LayerType::Err, + PacketData::UNDEF(_) => LayerType::End, } } } @@ -171,17 +202,17 @@ pub fn full_parse(packet: &[u8]) -> (usize, PacketData) { // Start the state machine with an undefined packet and a next_type of ethernet packet let mut i = 0; let mut last_layer = PacketData::UNDEF(EmptyLayer::new()); - let mut next_type = LayerType::ETH; + let mut next_type = LayerType::Eth; loop { // Iterate matching on the state match next_type { - LayerType::ETH => { + LayerType::Eth => { // ethernet state - unwrap the last layer as undefined let last_layer_data = last_layer.unwrap_undef(); // parse the data (starting from i) let (eth_layer, size, network_layer_type) = EthernetPacket::parse(last_layer_data, &packet[i..]); // save the ethernet packet, increment i, and set the next type to be what the Ethernet Packet dictates - last_layer = PacketData::ETH(eth_layer); + last_layer = PacketData::Eth(eth_layer); i += size; next_type = network_layer_type; } @@ -195,7 +226,7 @@ pub fn full_parse(packet: &[u8]) -> (usize, PacketData) { i += size; next_type = transport_layer_type; } - LayerType::ARP => { + LayerType::Arp => { // ARP state - unwrap the last layer as ethernet let last_layer_data = last_layer.unwrap_eth(); // parse the data (starting from i) @@ -205,7 +236,7 @@ pub fn full_parse(packet: &[u8]) -> (usize, PacketData) { i += size; next_type = transport_layer_type; } - LayerType::UDP => { + LayerType::Udp => { // UDP state - unwrap the last layer as IP let last_layer_data = last_layer.unwrap_ip(); // parse the data (starting from i) @@ -215,11 +246,11 @@ pub fn full_parse(packet: &[u8]) -> (usize, PacketData) { i += size; next_type = application_layer_type; } - LayerType::ICMP => { + LayerType::Icmp => { // ICMP is unimplemented. Return an undefined packet with 0 data parsed return (0, PacketData::UNDEF(EmptyLayer::new())); } - LayerType::DHCP => { + LayerType::Dhcp => { // DHCP state - unwrap the last layer as UDP let last_layer_data = last_layer.unwrap_udp(); // parse the data (starting from i) @@ -229,7 +260,7 @@ pub fn full_parse(packet: &[u8]) -> (usize, PacketData) { i += size; next_type = empty_type; } - LayerType::TCP => { + LayerType::Tcp => { // TCP state - unwrap the last layer as IP let last_layer_data = last_layer.unwrap_ip(); // parse the data (starting from i) @@ -239,11 +270,11 @@ pub fn full_parse(packet: &[u8]) -> (usize, PacketData) { i += size; next_type = empty_type; } - LayerType::ERR => { + LayerType::Err => { // Got an error so return an error packet with 0 data parsed return (0, PacketData::ERR(EmptyLayer::new())); } - LayerType::END => { + LayerType::End => { // Reached a safe end state for the state machine. Return the last layer and how much data we parsed return (i, last_layer); } diff --git a/kernel/src/network/network_query.rs b/kernel/src/network/network_query.rs index 4a811f7..1204819 100644 --- a/kernel/src/network/network_query.rs +++ b/kernel/src/network/network_query.rs @@ -53,14 +53,14 @@ impl NetworkQuery { } // Unwrap and type check the data let pkt_data = pkt.unwrap(); - if pkt_data.get_type() != LayerType::ARP { + if pkt_data.get_type() != LayerType::Arp { continue; } let arp_pkt = pkt_data.unwrap_arp(); // todo: add to the ARP table // Once we've unwrapped the packet, we can close the socket and return the sender mac socket.close(); - return Some(arp_pkt.sender_mac.val()); + return Some(arp_pkt.src_mac.val()); } else { // If we timed-out disable_network_interrupts(); diff --git a/kernel/src/network/processing.rs b/kernel/src/network/processing.rs index 9d2a05f..5c4d396 100644 --- a/kernel/src/network/processing.rs +++ b/kernel/src/network/processing.rs @@ -8,15 +8,16 @@ use crate::{ arp::ArpPacket, arp_table::ArpEntry, constants::{ARP_PORT, BROADCAST_ADDR, TCP_SYN}, + errors::NetworkErrors, ethernet::{EthType, EthernetPacket}, ip::{IPPacket, Protocol}, - layer::{Layer, LayerType, PacketData}, + layer::{HasChecksum, Layer, LayerType, PacketData}, raw_socket::wake_sockets, rtl8139::{disable_network_interrupts, enable_network_interrupts, NET_INFO}, tcp::TCPPacket, - tcp_session::{SessionAction, TCPSession}, errors::NetworkErrors, + tcp_session::{SessionAction, TCPSession}, }, - println, print, + print, println, }; use core::{ @@ -31,7 +32,7 @@ static PROCESS_VEC_WAKER: AtomicWaker = AtomicWaker::new(); /// An array queue for data to parse static PENDING_DATA: OnceCell>> = OnceCell::uninit(); -/// A empty struct with behavior +/// A empty struct with behavior /// - the idea is its a queue for the interrupt handler to enqueue vectors to represent packets /// - we have a task that polls this stream and processes them struct PendingProcessingStream { @@ -85,16 +86,16 @@ pub(crate) fn add_pkt_data(data: Vec) { } } -/// Start the processing of packets +/// Start the processing of packets /// - this function never terminates until we receive None from the raw packet stream (which won't ever happen) pub async fn init_packet_processing() { // Create the stream - let mut raw_packets = PendingProcessingStream::new(); + let mut raw_packets = PendingProcessingStream::new(); // Get the next piece of data the interrupt handler drew from the card while let Some(pkt_data) = raw_packets.next().await { // Parse it let amount_parsed_and_pkt = full_parse(pkt_data.as_slice()); - if amount_parsed_and_pkt.1.get_type() == LayerType::ERR || amount_parsed_and_pkt.1.get_type() == LayerType::ICMP { + if amount_parsed_and_pkt.1.get_type() == LayerType::Err || amount_parsed_and_pkt.1.get_type() == LayerType::Icmp { // Don't deal with unrecognized packets (will fail assert) continue; } @@ -106,138 +107,155 @@ pub async fn init_packet_processing() { // Get the device fields let rtl_dev_info = rtl_dev_info_guard.get_mut().unwrap(); let mut rtl_dev_config = NET_INFO.config.lock(); - match amount_parsed_and_pkt.1 { - PacketData::ARP(arp) => { - // todo: also check for broadcast, and expire from arp table - if arp.recp_ip.val() == rtl_dev_info.my_ip_address.unwrap_or(0) { - // WE GOT A REQUEST, so create a response - let eth_layer = EthernetPacket::gen(arp.sender_mac.val(), rtl_dev_info.mac_address.unwrap(), EthType::Arp); - let ip_address = rtl_dev_info.my_ip_address.unwrap_or(BROADCAST_ADDR); - let arp_layer = ArpPacket::gen(eth_layer, ip_address, arp.sender_ip.val(), false); - let arp_pkt = arp_layer.serialize(); - // and send it - rtl_dev_info.send_packet(&arp_pkt); - } else { - // WE GOT A RESPONSE, so save it in the arp table - rtl_dev_info.arp_table.push(ArpEntry { - mac: arp.sender_mac.val(), - ip: arp.sender_ip.val(), - expires: 0, - }); - // If there was some process listening on the ARP "port" -> then we have to upstream the packet - if rtl_dev_info.open_ports.contains(&ARP_PORT) { - // if we are listening on the port, try to insert initialize it into the map - if !rtl_dev_info.ports.contains_key(&ARP_PORT) { - rtl_dev_info.ports.insert(ARP_PORT, VecDeque::new()); + (|| { + match amount_parsed_and_pkt.1 { + PacketData::ARP(arp) => { + // todo: also check for broadcast, and expire from arp table + if arp.dest_ip.val() == rtl_dev_info.my_ip_address.unwrap_or(0) { + // WE GOT A REQUEST, so create a response + let eth_layer = EthernetPacket::gen(arp.src_mac.val(), rtl_dev_info.mac_address.unwrap(), EthType::Arp); + let ip_address = rtl_dev_info.my_ip_address.unwrap_or(BROADCAST_ADDR); + let arp_layer = ArpPacket::gen(eth_layer, ip_address, arp.src_ip.val(), false); + let arp_pkt = arp_layer.serialize(); + // and send it + rtl_dev_info.send_packet(&arp_pkt); + } else { + // WE GOT A RESPONSE, so save it in the arp table + rtl_dev_info.arp_table.push(ArpEntry { + mac: arp.src_mac.val(), + ip: arp.src_ip.val(), + expires: 0, + }); + // If there was some process listening on the ARP "port" -> then we have to upstream the packet + if rtl_dev_info.open_ports.contains(&ARP_PORT) { + // if we are listening on the port, try to insert initialize it into the map + if !rtl_dev_info.ports.contains_key(&ARP_PORT) { + rtl_dev_info.ports.insert(ARP_PORT, VecDeque::new()); + } + // Push the packet into the port structure and wake the port + rtl_dev_info.ports.get_mut(&ARP_PORT).unwrap().push_back(Ok(PacketData::ARP(arp))); + wake_sockets(ARP_PORT); } - // Push the packet into the port structure and wake the port - rtl_dev_info.ports.get_mut(&ARP_PORT).unwrap().push_back(Ok(PacketData::ARP(arp))); - wake_sockets(ARP_PORT); } } - } - PacketData::DHCP(dhcp) => { - // DHCP packet - let dst_port = dhcp.udp.dest_port.val() as u64; - println!("[HANDLER] Found DHCP packet"); - // If we are listening on the DHCP port - if rtl_dev_info.open_ports.contains(&dst_port) { - println!("[HANDLER] Port {} is open", dst_port); - // Try to initialize the port data structure - if !rtl_dev_info.ports.contains_key(&dst_port) { - rtl_dev_info.ports.insert(dst_port, VecDeque::new()); + PacketData::DHCP(mut dhcp) => { + if !dhcp.verify_checksum() { + // return to leave the match statement (its wrapped in a closure), dropping this packet + return; } - // Push back the dhcp packet and wake the port - rtl_dev_info.ports.get_mut(&dst_port).unwrap().push_back(Ok(PacketData::DHCP(dhcp))); - wake_sockets(dst_port); - } - } - PacketData::UDP(udp) => { - // UDP packet - let dst_port = udp.dest_port.val() as u64; - // If we are listening on the port - if rtl_dev_info.open_ports.contains(&dst_port) { - // Try to initialize the port data structure - if !rtl_dev_info.ports.contains_key(&dst_port) { - rtl_dev_info.ports.insert(dst_port, VecDeque::new()); + // DHCP packet + let dest_port = dhcp.udp.dest_port.val() as u64; + println!("[HANDLER] Found DHCP packet"); + // If we are listening on the DHCP port + if rtl_dev_info.open_ports.contains(&dest_port) { + println!("[HANDLER] Port {} is open", dest_port); + // Try to initialize the port data structure + if !rtl_dev_info.ports.contains_key(&dest_port) { + rtl_dev_info.ports.insert(dest_port, VecDeque::new()); + } + // Push back the dhcp packet and wake the port + rtl_dev_info + .ports + .get_mut(&dest_port) + .unwrap() + .push_back(Ok(PacketData::DHCP(dhcp))); + wake_sockets(dest_port); } - // Push back the UDP packet and wake the port - rtl_dev_info.ports.get_mut(&dst_port).unwrap().push_back(Ok(PacketData::UDP(udp))); - wake_sockets(dst_port); } - } - PacketData::TCP(tcp) => { - // TCP Packet - let dst_port = tcp.dest_port.val() as u64; - if !rtl_dev_info.open_ports.contains(&dst_port) { - // If we aren't listening on the port, throw the packet out - continue; - } - // Try to initialize the port structure - if !rtl_dev_info.ports.contains_key(&dst_port) { - rtl_dev_info.ports.insert(dst_port, VecDeque::new()); - } - // Create the session key - let session_key = TCPSession::gen_session_key(tcp.ip.source_ip.val(), tcp.src_port.val(), tcp.dest_port.val()); - // Open up a session if it doesn't exist - if !rtl_dev_config.tcp_sessions.contains_key(&session_key) { - if (tcp.get_flags() & TCP_SYN) == 0 { - // Don't create new session when there is no request for syncing - drop(rtl_dev_info_guard); - enable_network_interrupts(); + PacketData::UDP(mut udp) => { + if !udp.verify_checksum() { + println!("Cannot verify checksum"); + // return to leave the match statement (its wrapped in a closure), dropping this packet return; } - // Compact the first packet we receive as a template for future requests - let eth_layer = EthernetPacket::gen(tcp.ip.eth.src_mac.val(), tcp.ip.eth.dest_mac.val(), EthType::IPv4); - // ! IMP: leaving size undefined for the template - let ip_layer = IPPacket::gen(eth_layer, 0, Protocol::TCP, tcp.ip.destination_ip.val(), tcp.ip.source_ip.val()); - let tcp_layer = TCPPacket::gen(ip_layer, tcp.dest_port.val(), tcp.src_port.val()); - // Create a session with the template - let session = TCPSession::new(tcp_layer, tcp.ip.source_ip.val(), tcp.src_port.val(), tcp.dest_port.val()); - // Lets upstream our recevied packet to the listening port (and wake it) - // - This is used to split the socket into an "Established" session socket and a "Listening" socket - rtl_dev_info - .ports - .get_mut(&dst_port) - .unwrap() - .push_back(Ok(PacketData::TCP(tcp.clone()))); - wake_sockets(dst_port); - // Finally insert our tcp session - rtl_dev_config.tcp_sessions.insert(session.session_key(), session); + // UDP packet + let dest_port = udp.dest_port.val() as u64; + // If we are listening on the port + if rtl_dev_info.open_ports.contains(&dest_port) { + // Try to initialize the port data structure + if !rtl_dev_info.ports.contains_key(&dest_port) { + rtl_dev_info.ports.insert(dest_port, VecDeque::new()); + } + // Push back the UDP packet and wake the port + rtl_dev_info.ports.get_mut(&dest_port).unwrap().push_back(Ok(PacketData::UDP(udp))); + wake_sockets(dest_port); + } } - // Get the tcp session - let tcp_session = rtl_dev_config.tcp_sessions.get_mut(&session_key).unwrap(); + PacketData::TCP(mut tcp) => { + if !tcp.verify_checksum() { + // return to leave the match statement (its wrapped in a closure), dropping this packet + return; + } + // TCP Packet + let dest_port = tcp.dest_port.val() as u64; + if !rtl_dev_info.open_ports.contains(&dest_port) { + // If we aren't listening on the port, throw the packet out + return; + } + // Try to initialize the port structure + if !rtl_dev_info.ports.contains_key(&dest_port) { + rtl_dev_info.ports.insert(dest_port, VecDeque::new()); + } + // Create the session key + let session_key = TCPSession::gen_session_key(tcp.ip.src_ip.val(), tcp.src_port.val(), tcp.dest_port.val()); + // Open up a session if it doesn't exist + if !rtl_dev_config.tcp_sessions.contains_key(&session_key) { + if (tcp.get_flags() & TCP_SYN) == 0 { + // Don't create new session when there is no request for syncing + return; + } + // Compact the first packet we receive as a template for future requests + let eth_layer = EthernetPacket::gen(tcp.ip.eth.src_mac.val(), tcp.ip.eth.dest_mac.val(), EthType::IPv4); + // IMP: leaving size undefined for the template (since data size will change) + let ip_layer = IPPacket::gen(eth_layer, 0, Protocol::Tcp, tcp.ip.dest_ip.val(), tcp.ip.src_ip.val()); + let tcp_layer = TCPPacket::gen(ip_layer, tcp.dest_port.val(), tcp.src_port.val()); + // Create a session with the template + let session = TCPSession::new(tcp_layer, tcp.ip.src_ip.val(), tcp.src_port.val(), tcp.dest_port.val()); + // Lets upstream our recevied packet to the listening port (and wake it) + // - This is used to split the socket into an "Established" session socket and a "Listening" socket + rtl_dev_info + .ports + .get_mut(&dest_port) + .unwrap() + .push_back(Ok(PacketData::TCP(tcp.clone()))); + wake_sockets(dest_port); + // Finally insert our tcp session + rtl_dev_config.tcp_sessions.insert(session.session_key(), session); + } + // Get the tcp session + let tcp_session = rtl_dev_config.tcp_sessions.get_mut(&session_key).unwrap(); - // Generate an acknowledgement via the session's process_recv function - let ack_pkt = tcp_session.process_recv(&tcp); - if let Some(response) = ack_pkt.0 { - // If we got a response packet to send back, send it - // If no ack is receeived, the host will send another transmission for us to respond to - rtl_dev_info.send_packet(&response.serialize()); - } - // Interpret the action from the process_recv function - if ack_pkt.1 != SessionAction::Drop && rtl_dev_info.open_ports.contains(&session_key) { - // if we are listening on the session, try to init the packet queue on that end - if !rtl_dev_info.ports.contains_key(&session_key) { - rtl_dev_info.ports.insert(session_key, VecDeque::new()); + // Generate an acknowledgement via the session's process_recv function + let ack_pkt = tcp_session.process_recv(&tcp); + if let Some(response) = ack_pkt.0 { + // If we got a response packet to send back, send it + // If no ack is receeived, the host will send another transmission for us to respond to + rtl_dev_info.send_packet(&response.serialize()); + } + // Interpret the action from the process_recv function + if ack_pkt.1 != SessionAction::Drop && rtl_dev_info.open_ports.contains(&session_key) { + // if we are listening on the session, try to init the packet queue on that end + if !rtl_dev_info.ports.contains_key(&session_key) { + rtl_dev_info.ports.insert(session_key, VecDeque::new()); + } + let res = if ack_pkt.1 == SessionAction::PushUpstream { + // If the action is upstream, then we push the packet to the raw socket to handle it's data (if present) + Ok(PacketData::TCP(tcp)) + } else if ack_pkt.1 == SessionAction::EndOfStream { + // If the action is end of stream, we push the sentinel to indicate the stream has finished + Err(NetworkErrors::ClosedSocket) + } else { + // We shouldn't reach this unless we added another action and didn't handle it + unreachable!(); + }; + // Push back the packet or end-of-stream token to the session socket -> and then wake the socket + rtl_dev_info.ports.get_mut(&session_key).unwrap().push_back(res); + wake_sockets(session_key); } - let res = if ack_pkt.1 == SessionAction::PushUpstream { - // If the action is upstream, then we push the packet to the raw socket to handle it's data (if present) - Ok(PacketData::TCP(tcp)) - } else if ack_pkt.1 == SessionAction::EndOfStream { - // If the action is end of stream, we push the sentinel to indicate the stream has finished - Err(NetworkErrors::ClosedSocket) - } else { - // We shouldn't reach this unless we added another action and didn't handle it - unreachable!(); - }; - // Push back the packet or end-of-stream token to the session socket -> and then wake the socket - rtl_dev_info.ports.get_mut(&session_key).unwrap().push_back(res); - wake_sockets(session_key); } + _ => {} // ignore other packets } - _ => {} // ignore other packets - } + })(); // Release the guard and enable interrupts drop(rtl_dev_info_guard); enable_network_interrupts(); diff --git a/kernel/src/network/rtl8139.rs b/kernel/src/network/rtl8139.rs index c63b1fc..aa48063 100644 --- a/kernel/src/network/rtl8139.rs +++ b/kernel/src/network/rtl8139.rs @@ -21,7 +21,6 @@ use crate::network::constants::{ RX_START_REG, RX_MULTICAST, RX_PHYSICAL_MATCH, RX_PROMISCOUS, CONFIG_1_REG, }; use crate::network::raw_array::WrappingRawArray; -use crate::print; use super::constants::{IMR_REG, INTERRUPT_MASK, ISR_REG, ROK, RTL_DEV, RTL_VEND, TOK, TRANSMIT_CMD, TRANSMIT_REG}; use super::errors::NetworkErrors; diff --git a/kernel/src/network/socket.rs b/kernel/src/network/socket.rs index 1f63644..ab01f59 100644 --- a/kernel/src/network/socket.rs +++ b/kernel/src/network/socket.rs @@ -129,13 +129,13 @@ impl Socket { } // Unwrap the packet let pkt = pkt_or_err.unwrap(); - if pkt.get_type() == LayerType::UDP && self.socket_type == SocketType::UDP { + if pkt.get_type() == LayerType::Udp && self.socket_type == SocketType::UDP { // If we have a match (UDP <--> UDP) // then we unwrap the udp packet and set the dest_port, dest_mac, and dest_ip let udp_pkt = pkt.unwrap_udp(); self.dest_port = udp_pkt.src_port.val(); self.dest_mac = udp_pkt.ip.eth.src_mac.val(); - self.dest_ip = udp_pkt.ip.source_ip.val(); + self.dest_ip = udp_pkt.ip.src_ip.val(); // and transition into ready state self.socket_state = SocketState::Ready; @@ -154,12 +154,12 @@ impl Socket { enable_network_interrupts(); // and return None return None; - } else if pkt.get_type() == LayerType::TCP && self.socket_type == SocketType::TCP { + } else if pkt.get_type() == LayerType::Tcp && self.socket_type == SocketType::TCP { // If we have a match (TCP <--> TCP) // Then we unwrap the TCP packet let tcp_pkt = pkt.unwrap_tcp(); // And extract and save the dest_address and dest_port - let dest_address = tcp_pkt.ip.source_ip.val(); + let dest_address = tcp_pkt.ip.src_ip.val(); let dest_port = tcp_pkt.src_port.val(); // We create a session key let session_key = TCPSession::gen_session_key(dest_address, dest_port, self.src_port); @@ -335,7 +335,7 @@ impl Socket { return Err(err); } let pkt = pkt_or_err.unwrap(); - if pkt.get_type() == LayerType::UDP { + if pkt.get_type() == LayerType::Udp { // If we have a matching UDP packet, we pass the data to the read result let udp_pkt = pkt.unwrap_udp(); return Ok(udp_pkt.data); @@ -359,7 +359,7 @@ impl Socket { return Err(err); } let pkt = pkt_or_err.unwrap(); - if pkt.get_type() == LayerType::TCP { + if pkt.get_type() == LayerType::Tcp { // If our packet matches our intended type, save the packet into our buffer let mut tcp_pkt = pkt.unwrap_tcp(); res_vec.append(&mut tcp_pkt.data); @@ -402,24 +402,19 @@ impl Socket { // Build a UDP packet let eth_layer = EthernetPacket::gen(self.src_mac, self.dest_mac, ethernet::EthType::IPv4); let udp_size = UDPPacket::packet_size() + data.len() as u16; - let ip_layer = IPPacket::gen(eth_layer, udp_size, Protocol::UDP, self.src_address, self.dest_ip); + let ip_layer = IPPacket::gen(eth_layer, udp_size, Protocol::Udp, self.src_address, self.dest_ip); let data_len = data.len(); let mut udp_layer = UDPPacket::gen(ip_layer, self.src_port, self.dest_port, data_len as u16); udp_layer.data = data.to_vec(); // todo: split the data and return the amount actually written! - - // Calculate checksum - let data_2_send = udp_layer.serialize(); - let start_udp = data_2_send.len() - (UDPPacket::packet_size() as usize + data_len); - let start_ip = start_udp - (IPPacket::packet_size() as usize); - udp_layer.ip.calculate_checksum(&data_2_send[start_ip..start_udp]); - udp_layer.calculate_checksum(&data_2_send[start_udp..]); + udp_layer.ip.calculate_checksum(); + udp_layer.calculate_checksum(); // Serialize - let data_2_send_final = udp_layer.serialize(); + let packet_data = udp_layer.serialize(); // Send the packet disable_network_interrupts(); - NET_INFO.lock().get_ref().unwrap().send_packet(&data_2_send_final); + NET_INFO.lock().get_ref().unwrap().send_packet(&packet_data); enable_network_interrupts(); // Return how much was written @@ -466,7 +461,7 @@ impl Socket { return Err(err); } let pkt = pkt_or_err.unwrap(); - if pkt.get_type() == LayerType::TCP { + if pkt.get_type() == LayerType::Tcp { // We got a TCP packet let mut tcp_session_guard = NET_INFO.config.lock(); // Check the acknowledgement to make sure everything is acked diff --git a/kernel/src/network/tcp.rs b/kernel/src/network/tcp.rs index 707af71..be02ca6 100644 --- a/kernel/src/network/tcp.rs +++ b/kernel/src/network/tcp.rs @@ -1,10 +1,13 @@ use alloc::vec; use alloc::vec::Vec; +use crate::println; + use super::{ bytefield::{Bytefield16, Bytefield32}, + ethernet::EthernetPacket, ip::IPPacket, - layer::{HasChecksum, Layer, LayerType}, + layer::{calculate_checksum_inner, HasChecksum, Layer, LayerType}, }; /// A TCP packet, implements Layer and HasChecksum (20 bytes) @@ -65,7 +68,7 @@ impl TCPPacket { ack_num: Bytefield32::new(0), // a value of 5 in the header_offset (5*4 = 20 bits -> b/c no options) flags: Bytefield16::new(5 << 12), - // sliding window :(, pain to implement... + // sliding window :(, pain to implement... // we can just allow unlimited data... Also we would like to keep track of how much we are allowed to send sliding_window: Bytefield16::new(u16::MAX), checksum: Bytefield16::new(0), @@ -121,7 +124,7 @@ impl Layer for TCPPacket { { // Create an empty packet let mut packet = TCPPacket::new(); - + // Save ip packet and read 14 bytes let mut i = 0; packet.ip = ip_layer; @@ -148,7 +151,7 @@ impl Layer for TCPPacket { i += 1; } // Return the packet, the amount of data consumed, and the next layer type (end) - (packet, i, LayerType::END) + (packet, i, LayerType::End) } /// Serialize the packet into a vector of bytes, ready to send over the network @@ -179,18 +182,16 @@ impl Layer for TCPPacket { impl HasChecksum for TCPPacket { /// Calculate a checksum on the data and the packet /// - will self mutate - fn calculate_checksum(&mut self, data: &[u8]) { + fn calculate_checksum(&mut self) { // Starting vars let mut sum: u32 = 0; - // calculating checksum on serialized bytefield (so its network byte order and must be swapped) - let mut tcp_len = self.total_size(); // First we do the IP as a pseduo header let ip = &self.ip; - sum += (ip.source_ip.data[0] as u32) | (ip.source_ip.data[1] as u32) << 8; - sum += (ip.source_ip.data[2] as u32) | (ip.source_ip.data[3] as u32) << 8; - sum += (ip.destination_ip.data[0] as u32) | (ip.destination_ip.data[1] as u32) << 8; - sum += (ip.destination_ip.data[2] as u32) | (ip.destination_ip.data[3] as u32) << 8; + sum += (ip.src_ip.data[0] as u32) | (ip.src_ip.data[1] as u32) << 8; + sum += (ip.src_ip.data[2] as u32) | (ip.src_ip.data[3] as u32) << 8; + sum += (ip.dest_ip.data[0] as u32) | (ip.dest_ip.data[1] as u32) << 8; + sum += (ip.dest_ip.data[2] as u32) | (ip.dest_ip.data[3] as u32) << 8; // Sum protocol and length let protocol = Bytefield16::new(ip.protocol as u16); @@ -201,29 +202,70 @@ impl HasChecksum for TCPPacket { // Zero the checksum field self.checksum = Bytefield16::new(0); - // Sum the body - let mut ptr = 0; - while tcp_len > 1 { - sum += (data[ptr] as u32) | ((data[ptr + 1] as u32) << 8); - tcp_len -= 2; - ptr += 2; - } + // Calculate checksum on body + let data = self.serialize(); + let start_tcp = IPPacket::packet_size() + EthernetPacket::packet_size(); + let res = calculate_checksum_inner(&data[start_tcp as usize..], sum); - if data.len() % 2 == 1 { - // Add the padding if the packet length is odd - sum += data[ptr] as u32; - } + // Set the checksum + self.checksum = Bytefield16::new(res); + } - // Add the carries - while sum > 0xFFFF { - sum = (sum & 0xFFFF) + (sum >> 16); + /// Check if the checksum is valid in the packet + fn verify_checksum(&mut self) -> bool { + // Clone the packet in host order + let mut ip: IPPacket = IPPacket { + eth: EthernetPacket::new(), + version_hlen: self.ip.version_hlen, + type_of_service: self.ip.type_of_service, + total_length: self.ip.total_length.swapped_endianness(), + identification: self.ip.identification.swapped_endianness(), + flags_fragment_offset: self.ip.flags_fragment_offset.swapped_endianness(), + ttl: self.ip.ttl, + protocol: self.ip.protocol, + checksum: Bytefield16::new(0), + src_ip: self.ip.src_ip.swapped_endianness(), + dest_ip: self.ip.dest_ip.swapped_endianness(), + }; + ip.calculate_checksum(); + if self.ip.checksum.swapped_endianness().val() != ip.checksum.val() { + return false; } - // One's complement and swap the bytes because we did our sum in big endian - // (and the bytefield will try to convert to big endian) - let res = u16::swap_bytes(!sum as u16); + // Clone the packet in host order + let mut tcp: TCPPacket = TCPPacket { + ip, + src_port: self.src_port.swapped_endianness(), + dest_port: self.dest_port.swapped_endianness(), + seq_num: self.seq_num.swapped_endianness(), + ack_num: self.ack_num.swapped_endianness(), + flags: self.flags.swapped_endianness(), + sliding_window: self.sliding_window.swapped_endianness(), + checksum: Bytefield16::new(0), + urgent: self.urgent.swapped_endianness(), + options: self.options.clone(), + data: self.data.clone(), + }; + // println!("SRC {:?} {}", tcp.src_port, self.src_port.val()); + // println!("DST {:?} {}", tcp.dest_port, self.dest_port.val()); + // println!("SEQ {:?} {}", tcp.seq_num, self.seq_num.val()); + // println!("ACK {:?} {}", tcp.ack_num, self.ack_num.val()); + // println!("FLG {:?} {}", tcp.flags, self.flags.val()); + // println!("SLW {:?} {}", tcp.sliding_window, self.sliding_window.val()); + // println!("CKS {:?} {}", tcp.checksum, self.checksum.val()); + // println!("URG {:?} {}", tcp.urgent, self.urgent.val()); + // println!("OPT {:?}", tcp.options); + // println!("DAT {:?}", tcp.data); - // Set the checksum - self.checksum = Bytefield16::new(res); + tcp.calculate_checksum(); + // println!("CHECKSUMS: {:#10x} {:#10x}", tcp.checksum.val(), self.checksum.swapped_endianness().val()); + // println!("CHECKSUMS: {} {}", tcp.checksum.val(), self.checksum.swapped_endianness().val()); + // todo: figure out why I am off by 1024 when data is empty. This is a terrible way to fix a bug (and doesn't always works) + if tcp.data.is_empty() { + true + // tcp.checksum.val() - 1024 == self.checksum.swapped_endianness().val() + } else { + tcp.checksum.val() == self.checksum.swapped_endianness().val() + } } } diff --git a/kernel/src/network/tcp_session.rs b/kernel/src/network/tcp_session.rs index ca6fba9..25c3231 100644 --- a/kernel/src/network/tcp_session.rs +++ b/kernel/src/network/tcp_session.rs @@ -122,13 +122,10 @@ impl TCPSession { tcp_pkt.ip.total_length = Bytefield16::new(TCPPacket::packet_size() + TCPPacket::options_size() + IPPacket::packet_size()); tcp_pkt.seq_num = Bytefield32::new(self.sent_data_amount); tcp_pkt.ack_num = Bytefield32::new(self.recv_data_amount); - + // Calculate checksums - let data = tcp_pkt.serialize(); - let start_tcp = IPPacket::packet_size() as usize + EthernetPacket::packet_size() as usize; - let start_ip = EthernetPacket::packet_size() as usize; - tcp_pkt.ip.calculate_checksum(&data[start_ip..start_tcp]); - tcp_pkt.calculate_checksum(&data[start_tcp..]); + tcp_pkt.ip.calculate_checksum(); + tcp_pkt.calculate_checksum(); // Return the packet Ok(tcp_pkt) @@ -160,12 +157,9 @@ impl TCPSession { tcp_pkt.ack_num = Bytefield32::new(self.recv_data_amount); // Calculate checksums - let data = tcp_pkt.serialize(); - let start_tcp = IPPacket::packet_size() as usize + EthernetPacket::packet_size() as usize; - let start_ip = EthernetPacket::packet_size() as usize; - tcp_pkt.ip.calculate_checksum(&data[start_ip..start_tcp]); - tcp_pkt.calculate_checksum(&data[start_tcp..]); - + tcp_pkt.ip.calculate_checksum(); + tcp_pkt.calculate_checksum(); + // Return packet Ok(tcp_pkt) } @@ -313,11 +307,9 @@ impl TCPSession { }; } // Calculate checksums - let data = response.serialize(); - let start_tcp = IPPacket::packet_size() as usize + EthernetPacket::packet_size() as usize; - let start_ip = EthernetPacket::packet_size() as usize; - response.ip.calculate_checksum(&data[start_ip..start_tcp]); - response.calculate_checksum(&data[start_tcp..]); + response.ip.calculate_checksum(); + response.calculate_checksum(); + // Return result and action (Some(response), response_action) diff --git a/kernel/src/network/udp.rs b/kernel/src/network/udp.rs index e5b0d9d..6947921 100644 --- a/kernel/src/network/udp.rs +++ b/kernel/src/network/udp.rs @@ -1,7 +1,10 @@ +use crate::println; + use super::{ bytefield::Bytefield16, + ethernet::EthernetPacket, ip::IPPacket, - layer::{HasChecksum, Layer, LayerType}, + layer::{calculate_checksum_inner, HasChecksum, Layer, LayerType}, }; use alloc::vec; @@ -11,7 +14,7 @@ use alloc::vec::Vec; #[derive(Debug)] pub struct UDPPacket { /// The parent packet - pub ip: IPPacket, + pub ip: IPPacket, /// Source port pub src_port: Bytefield16, /// Destination port @@ -19,10 +22,9 @@ pub struct UDPPacket { /// Length of data pub length: Bytefield16, /// The checksum - checksum: Bytefield16, + pub checksum: Bytefield16, /// a vector for data bytes if needed pub data: Vec, - } impl UDPPacket { @@ -44,15 +46,16 @@ impl UDPPacket { /// - dest_port: the destination port (open on the server) /// - length: the length of the body (how much data under it) pub fn gen(ip_layer: IPPacket, src_port: u16, dest_port: u16, length: u16) -> Self { - UDPPacket { + let mut udp = UDPPacket { ip: ip_layer, src_port: Bytefield16::new(src_port), dest_port: Bytefield16::new(dest_port), // size of body + 8 bytes for UDP - length: Bytefield16::new(length + 8), + length: Bytefield16::new(length + 8), checksum: Bytefield16::new(0), data: Vec::new(), - } + }; + udp } } @@ -68,8 +71,8 @@ impl Layer for UDPPacket { Self: Sized, { // create an empty packet - let mut packet = UDPPacket::new(); - + let mut packet = UDPPacket::new(); + // Save ip packet and read 14 bytes let mut i = 0; packet.ip = ip_layer; @@ -82,7 +85,7 @@ impl Layer for UDPPacket { // Match the destionation port to see if its DHCP let layer_type = match packet.dest_port.val() { // If port 68, send to DHCP layer - 68 => LayerType::DHCP, + 68 => LayerType::Dhcp, _ => { // read remaining bytes and place them into the data buffer for _ in 0..(packet.length.val() - 8) { @@ -91,7 +94,7 @@ impl Layer for UDPPacket { } assert!(i == packet.length.val() as usize); // We are done parsing - LayerType::END + LayerType::End } }; // Return the packet, the amount of data consumed, and the next layer type (end or DHCP) @@ -121,48 +124,62 @@ impl Layer for UDPPacket { impl HasChecksum for UDPPacket { /// Calculate a checksum on the data and the packet /// - will self mutate - fn calculate_checksum(&mut self, data: &[u8]) { + fn calculate_checksum(&mut self) { // Starting vars let mut sum: u32 = 0; - // calculating checksum on serialized bytefield (so its network byte order and must be swapped) - let mut udp_len = self.length.swapped_endianness().val() as usize; // First we do the IP as a pseduo header let ip = &self.ip; - sum += (ip.source_ip.data[0] as u32) | (ip.source_ip.data[1] as u32) << 8; - sum += (ip.source_ip.data[2] as u32) | (ip.source_ip.data[3] as u32) << 8; - sum += (ip.destination_ip.data[0] as u32) | (ip.destination_ip.data[1] as u32) << 8; - sum += (ip.destination_ip.data[2] as u32) | (ip.destination_ip.data[3] as u32) << 8; + sum += (ip.src_ip.data[0] as u32) | (ip.src_ip.data[1] as u32) << 8; + sum += (ip.src_ip.data[2] as u32) | (ip.src_ip.data[3] as u32) << 8; + sum += (ip.dest_ip.data[0] as u32) | (ip.dest_ip.data[1] as u32) << 8; + sum += (ip.dest_ip.data[2] as u32) | (ip.dest_ip.data[3] as u32) << 8; // Sum protocol and length let protocol = Bytefield16::new(ip.protocol as u16); sum += (protocol.data[0] as u32) | (protocol.data[1] as u32) << 8; sum += (self.length.data[0] as u32) | (self.length.data[1] as u32) << 8; - // Sum the body + // Calculate checksum on body self.checksum = Bytefield16::new(0); - let mut ptr = 0; - while udp_len > 1 { - sum += (data[ptr] as u32) | ((data[ptr + 1] as u32) << 8); - udp_len -= 2; - ptr += 2; - } + let data = self.serialize(); + let start_udp = IPPacket::packet_size() + EthernetPacket::packet_size(); + let res = calculate_checksum_inner(&data[start_udp as usize..], sum); - if data.len() % 2 == 1 { - // Add the padding if the packet length is odd - sum += data[ptr] as u32; - } + // Save the checksum + self.checksum = Bytefield16::new(res); + } - // Add the carries - while sum > 0xFFFF { - sum = (sum & 0xFFFF) + (sum >> 16); + /// Check if the checksum is valid in the packet + fn verify_checksum(&mut self) -> bool { + // Clone the packet in host order + let mut ip: IPPacket = IPPacket { + eth: EthernetPacket::new(), + version_hlen: self.ip.version_hlen, + type_of_service: self.ip.type_of_service, + total_length: self.ip.total_length.swapped_endianness(), + identification: self.ip.identification.swapped_endianness(), + flags_fragment_offset: self.ip.flags_fragment_offset.swapped_endianness(), + ttl: self.ip.ttl, + protocol: self.ip.protocol, + checksum: Bytefield16::new(0), + src_ip: self.ip.src_ip.swapped_endianness(), + dest_ip: self.ip.dest_ip.swapped_endianness(), + }; + ip.calculate_checksum(); + if self.ip.checksum.swapped_endianness().val() != ip.checksum.val() { + return false; } - // One's complement and swap the bytes because we did our sum in big endian - // (and the bytefield will try to convert to big endian) - let res = u16::swap_bytes(!sum as u16); - - // Return the one's complement of sum - self.checksum = Bytefield16::new(res); + let mut udp: UDPPacket = UDPPacket { + ip, + src_port: self.src_port.swapped_endianness(), + dest_port: self.dest_port.swapped_endianness(), + length: self.length.swapped_endianness(), + checksum: Bytefield16::new(0), + data: self.data.clone(), + }; + udp.calculate_checksum(); + udp.checksum.val() == self.checksum.swapped_endianness().val() } } From eefde48faa5f0145cc5621b9e0bf3be687c6a41a Mon Sep 17 00:00:00 2001 From: EthanLavi <89608076+EthanLavi@users.noreply.github.com> Date: Tue, 28 Nov 2023 16:46:10 -0500 Subject: [PATCH 22/36] More documentation --- kernel/src/.DS_Store | Bin 0 -> 8196 bytes kernel/src/network/README.md | 48 +++++++++++++++++++++++++++++++--- kernel/src/network/TODO.md | 13 +++++---- kernel/src/network/rtl8139.rs | 37 ++++++++++++++++---------- 4 files changed, 73 insertions(+), 25 deletions(-) create mode 100644 kernel/src/.DS_Store diff --git a/kernel/src/.DS_Store b/kernel/src/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..0495d1fc58f87aa382744a8bbba3e509217119a5 GIT binary patch literal 8196 zcmeHMT}&KR6u#d;VOFLvAQd`Y+7-0bN(xw8v1xVrD+P*>^0#aoad&4ZBf|{t%r4l( zxQ#}m@yWl}8jXqifX1IEwML`H#2BL?CZZ<(#OSL@6XnHt?%dg?KzTE!+PTTS-=1^s zxp(fF@7$X^%NRpvPH$tZo-rm<<)~IrbD6^JylyB`peZ2)`7`FQESt65d{^4q&^uIw z5eOp?Mj(tp7=bVXS3(5n%;rVjCR5`{2EC}?s4r=@>0AVSq z{X)O#9^m`L0*nV(5a_$onBw+;z!kw21Hzr`G2Wdp9$-PBaAy$i48hC@ZYc0)C%>5A zogpqTY{LkI5!e_3es?ctHZz#V7B0QNJ88yB4y5NT)Ad?gS0I&?ZYe95BiB?YTeZ=# z@d;%zVP+hA#5`k6sa5Wbv(QK4aMsA1#==C#v2t^^JEc}p^=(=7 zZ``SAC%L67Yf!5P4!ieg%BV|0P*Xw2TbdPRbk?+W)eWX z?Io@M`bf{6dUVrsipBw2#UvlDL)ZN^)Tq@}+OT#)R=<+zb(nVE^yqC63kY(q+!Rm7 z{n$6f_lkGWB!DqLppn|57$D#F8`y5vPTTw_n@D^lJI_96U$Ys80e1wnj3BJGzzQuR=1%KdAtV)ukNL!^!sZOewG--#lPim3cr4FfI z8j^;^vrEBA#FA^*B&A5+m{Rj;i>H*dborEu_8#l&SB_utDfLaINTO}m)x~&bwIAuY zNJ7EdB2h1;zt8zKnc{Qa=fpa`E>o)Qh&Cv@ikY}RL&qptj76E}rdlp~H!3&Clq~8N zA?{IbmMKltG9flc_sNu)s^p7GTXcVf(oZcBVsrHNh^#^ghm^w+N<_6(hz&|S5>c5i zM7yFrG9{BL|I;l0lC7{G*l+9tTSXZvpkfzpLMsj;j-%)$UJqap Syncing --> Established --> Closing --> Closed. In the Waiting state, we wait for a SYN packet. Once this happens, we go to the Syncing state. Here, we are finished the 3-way handshake. After completing the handshake, we are Established and can send/recv packets with data. After receiving a FIN packet, we go to the Closing state. Here we are doing the 4-way handshake for closing the connection cleanly. After doing so, we are in Closed and will drop any further packets. In this state, we will push an EndOfStream network error. When encountering this, it is the application's job to close the socket so we can reclaim the TCP session. These state changes happen through process_receive and process_send. There is also a template, which is created when the object is constructed. This is used to avoid having to fill in the common fields every time. + +## Init + +There are two init processes that are important. First, processing.rs must be setup. There is a function init_process_packet_data which will start an infinite task that will read packets from the device driver, parse them, and send them to the proper port. The second task is the init_dhcp which will find the IP address of the machine. + +## ARP + +ARP is a standalone service. Processing.rs will handle responding to ARP requests and dealing with ARP replies. This maintains the arp table (a vector of arp entries, see arp_table.rs). ARP has the job of translating IP addresses (OSI layer 3) into mac addresses (OSI layer 2). In network_query, there is also a clean function for sending/listening for the result of an ARP query to determine a mac address. + +## Pipeline + +To get a packet to the application layer, this is the pipeline it follows. + +rtl8139.rs --> processing.rs --> tcp_session.rs (if tcp packet) --> raw_socket.rs --> socket.rs ## TODOS @@ -16,7 +56,7 @@ TODO [] Fix synchronization to be much cleaner [] Clean up ugly stuff [] search for todo and fix thoses -[] Refactor to include more documentation on the network module +[] Everything in TODO.md ## Receiving packets @@ -24,7 +64,7 @@ Firstly, in rtl8139.rs, we receive a device interrupt that a new packet has arri ### u64 "ports" -Traditionally, outward facing ports are from 0-65535, or a u16. But to handle sockets that don't necessarily bind to these ports, the internal ports for receiving data goes up to a u64. As a result, we have a whole space to communicate on. For instance, ARP, while waiting for a response, listens on "port" 65537. This reservation allows the correct direction of packets to the interpretation of them. +Traditionally, outward facing ports are from 0-65535, or a u16. But to handle sockets that don't necessarily bind to these ports, the internal ports for receiving data goes up to a u64. As a result, we have a whole space to communicate on. For instance, ARP, while waiting for a response, listens on "port" 65537. This reservation allows the correct direction of packets to the interpretation of them. This u64 space also allows TCP sessions to have it's unique ID (src-mac + src-ip + dst-port) to serve as its "port". ### Timeouts diff --git a/kernel/src/network/TODO.md b/kernel/src/network/TODO.md index 5660b88..b0e0df9 100644 --- a/kernel/src/network/TODO.md +++ b/kernel/src/network/TODO.md @@ -1,9 +1,8 @@ # TODOs -* Refactor so that all of networking is tested -* Refactor to include more documentation on the network module -* Verify other parts of the packet (in rtl) -* Fix synchronization to be much cleaner -* Clean up ugly stuff -* Search for todo and fix thoses -* Benchmarking +* Refactor so that all of networking is tested (Fri) +* Fix synchronization to be much cleaner (Sat/Sun) +* Clean up ugly stuff -- final code review +* Search for todo and fix thoses (Sat) +* Benchmarking (Sun) +* Fix socket errors (Closing sockets... let's try different sources?) diff --git a/kernel/src/network/rtl8139.rs b/kernel/src/network/rtl8139.rs index aa48063..02f6e88 100644 --- a/kernel/src/network/rtl8139.rs +++ b/kernel/src/network/rtl8139.rs @@ -65,20 +65,25 @@ static mut INTERRUPTS_ARE_ENABLED: spin::Mutex = spin::Mutex:: pub fn disable_network_interrupts() { let mut data = unsafe { INTERRUPTS_ARE_ENABLED.lock() }; if data.get() == 0 { + // If the counter is 0, we can disable interrupts let mut port_imr = Port::::new((unsafe { IO_BASE } as u16) + IMR_REG); unsafe { port_imr.write(0x0) }; } + // then we increment the counter data.inc(); } // Enable network interrupts (is thread safe) pub fn enable_network_interrupts() { let mut data = unsafe { INTERRUPTS_ARE_ENABLED.lock() }; + // decrement the counter data.dec(); if data.get() == 0 { + // If the counter is 0, we can enable interrupts let mut port_imr = Port::::new((unsafe { IO_BASE } as u16) + IMR_REG); unsafe { port_imr.write(INTERRUPT_MASK) }; } + // Note: If the counter is not 0, someone else has disabled interrupts so we cannot enable it yet } pub extern "x86-interrupt" fn network_handle(_stack_frame: InterruptStackFrame) { @@ -133,24 +138,20 @@ fn recv_packet(rtl_dev_info: &RTL8139) { // Make sure buffer isn't empty let cmd_reg = (rtl_dev_info.config.io_base.unwrap() + CMD_REG) as u16; let mut cmd_port = Port::::new(cmd_reg); - // while unsafe { cmd_port.read() } & CMD_REG_BUFE == 0x0 { - if unsafe { cmd_port.read() } & CMD_REG_BUFE == 0x0 { + while unsafe { cmd_port.read() } & CMD_REG_BUFE == 0x0 { // Receive a packet by reading the buffer // ? Reading the buffer is naturally unsafe? Is there a better way? let virtual_buffer_recv: VirtAddr = VirtAddr::new(rtl_dev_info.recv_buffer.unwrap().as_u64() + rtl_dev_info.physical_mem_offset.unwrap()); - // todo: check for packet validity https://www.cs.usfca.edu/~cruse/cs326f04/RTL8139_ProgrammersGuide.pdf let vb_recv_start: *const u8 = virtual_buffer_recv.as_ptr(); let mut rx_buffer = WrappingRawArray::new(vb_recv_start, RX_BUFFER_SIZE as usize); rx_buffer.shift_amount(unsafe { RECV_POS } as usize); let header_buf = rx_buffer.trim(2); let header = (header_buf[0] as u16) | (header_buf[1] as u16) << 8; - // println!("Header {}", header); // Checking receive OK and no errors if header & 0x01 != 0 && header & 0x02 == 0 && header & 0x04 == 0 && header & 0x20 == 0 { let length_buf = rx_buffer.trim(2); // get the next two bytes let length = (length_buf[0] as u16) | (length_buf[1] as u16) << 8; - // println!("Length {}", length); let packet = rx_buffer.trim((length - 4) as usize); // ? throw out the crc... we don't need to check it... rx_buffer.shift_amount(4); @@ -208,7 +209,9 @@ impl NetworkConfig { } impl RTL8139 { - // Initialize the card + /// Initialize the card + /// Will allocate frames to be the physical send/recv buffers + /// Will use physical_mem_offset to calculate physical addresses from the frames pub fn init(&mut self, frame_allocator: &mut BootInfoFrameAllocator, physical_mem_offset: u64) -> bool { let mut frames: Vec = vec![]; // create recv and send buffer space @@ -244,6 +247,8 @@ impl RTL8139 { setup_status } + /// Create a new driver instance with the device + /// Will return None if no Device matches the correct vendor and device ID fn new(configs: Vec) -> Option { let mut use_dev = None; for dev in configs.iter() { @@ -271,9 +276,9 @@ impl RTL8139 { None } - // True on success - // Interrupt handler and frame allocation must be done as a pre-req - // Also will set the mac address of the struct... (before then, the mac address is 0) + /// Interrupt handler and frame allocation must be done as a pre-req + /// Also will set the mac address of the struct... (before then, the mac address is 0) + /// @return True on success fn setup(&mut self) -> bool { // disable the interrupt line to prevent early triggering of an interrupt before we register the handler InterruptHandler::block_irq(self.config.irq.unwrap()); @@ -364,7 +369,10 @@ impl RTL8139 { true } + /// Write the packet data to the card and notify + /// This will send the packet pub fn send_packet(&self, packet_data: &Vec) { + // If our buffers are invalid, we panic if self.send_buffer.is_none() || self.physical_mem_offset.is_none() { panic!("RTL8139 is not initialized properly"); } @@ -373,31 +381,32 @@ impl RTL8139 { if self.mac_address.is_none() || io_base.is_none() { return; } - + // Create a virtual address in which we can write the contents of the packet let virtual_buffer: VirtAddr = VirtAddr::new(self.send_buffer.unwrap().as_u64() + self.physical_mem_offset.unwrap()); let virtual_buffer_ptr: *mut u8 = virtual_buffer.as_mut_ptr(); // Copy data into buffer for (i, item) in packet_data.iter().enumerate() { unsafe { *(virtual_buffer_ptr.wrapping_add(i)) = *item }; } + // Padding the packet with 0s if packet_data.len() < 60 { for j in 0..(60 - packet_data.len()) { - // pad up to 60 with 0s unsafe { *(virtual_buffer_ptr.wrapping_add(packet_data.len() + j)) = 0 }; } } - // TODO: Make this part of self... + // Which descriptor to send the packet from. RTL8139 uses a round robin of these let reg = TRANSMIT_REG[unsafe { TRANSMIT_IDX as usize }]; let cmd = TRANSMIT_CMD[unsafe { TRANSMIT_IDX as usize }]; - + // Ports to the descriptor let mut reg_port = Port::::new((io_base.unwrap() + reg) as u16); let mut cmd_port = Port::::new((io_base.unwrap() + cmd) as u16); + // Write the packet length and the virtual address of the packet to the card unsafe { reg_port.write(virtual_buffer.as_u64() as u32); cmd_port.write(max(packet_data.len(), 60) as u32); } - // Send the packet from the buffer + // Cycle through the descriptor indexes unsafe { TRANSMIT_IDX += 1; TRANSMIT_IDX %= 4; From f6108ad87e00375217153bdb57a54e774b2b73e1 Mon Sep 17 00:00:00 2001 From: EthanLavi <89608076+EthanLavi@users.noreply.github.com> Date: Wed, 29 Nov 2023 00:02:20 -0500 Subject: [PATCH 23/36] Temporary fix for network stack --- kernel/src/network/processing.rs | 4 +++- kernel/src/network/raw_socket.rs | 7 +++++-- kernel/src/network/rtl8139.rs | 2 ++ 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/kernel/src/network/processing.rs b/kernel/src/network/processing.rs index 5c4d396..9751f7d 100644 --- a/kernel/src/network/processing.rs +++ b/kernel/src/network/processing.rs @@ -232,8 +232,10 @@ pub async fn init_packet_processing() { // If no ack is receeived, the host will send another transmission for us to respond to rtl_dev_info.send_packet(&response.serialize()); } + // todo: Release TCP resources after 5 minutes if never closed + // todo: remove the // Interpret the action from the process_recv function - if ack_pkt.1 != SessionAction::Drop && rtl_dev_info.open_ports.contains(&session_key) { + if ack_pkt.1 != SessionAction::Drop { // if we are listening on the session, try to init the packet queue on that end if !rtl_dev_info.ports.contains_key(&session_key) { rtl_dev_info.ports.insert(session_key, VecDeque::new()); diff --git a/kernel/src/network/raw_socket.rs b/kernel/src/network/raw_socket.rs index d366024..5cca964 100644 --- a/kernel/src/network/raw_socket.rs +++ b/kernel/src/network/raw_socket.rs @@ -130,6 +130,9 @@ impl Stream for RawSocket { /// Wake sockets by port pub(crate) fn wake_sockets(port: u64) { - // wake the port up - NEW_PACKET_WAKER.lock()[&port].wake(); + let guard = NEW_PACKET_WAKER.lock(); + if guard.contains_key(&port){ + // wake the port up, if possible + guard[&port].wake(); + } } diff --git a/kernel/src/network/rtl8139.rs b/kernel/src/network/rtl8139.rs index 02f6e88..3775595 100644 --- a/kernel/src/network/rtl8139.rs +++ b/kernel/src/network/rtl8139.rs @@ -194,6 +194,7 @@ pub struct RTL8139 { pub open_ports: HashSet, pub ports: HashMap>>, pub arp_table: Vec, + pub to_expire: VecDeque, } pub struct NetworkConfig { @@ -271,6 +272,7 @@ impl RTL8139 { open_ports: HashSet::with_capacity(10), ports: HashMap::with_capacity(10), arp_table: Vec::new(), + to_expire: VecDeque::with_capacity(20), }); } None From 02c61aa68883932d9a4427535a50d995209b7608 Mon Sep 17 00:00:00 2001 From: EthanLavi <89608076+EthanLavi@users.noreply.github.com> Date: Mon, 4 Dec 2023 10:31:56 -0500 Subject: [PATCH 24/36] ARP timeout --- kernel/src/interrupts.rs | 1 - kernel/src/network/arp_table.rs | 22 +++++++++++++++++++--- kernel/src/network/network_query.rs | 8 ++++++-- kernel/src/network/processing.rs | 15 ++++++--------- kernel/src/task/timeout.rs | 6 +++++- 5 files changed, 36 insertions(+), 16 deletions(-) diff --git a/kernel/src/interrupts.rs b/kernel/src/interrupts.rs index c2c3f78..b61a86a 100644 --- a/kernel/src/interrupts.rs +++ b/kernel/src/interrupts.rs @@ -57,7 +57,6 @@ impl InterruptHandler { /// Static function for unblocking an interrupt by irq_num pub fn unblock_irq(irq_num: u8) { - // todo: disable interrupts, maybe this is why we sometimes deadlock let mut locked_pics = PICS.lock(); let data = unsafe { locked_pics.read_masks() }; // set the irq bit to 0 diff --git a/kernel/src/network/arp_table.rs b/kernel/src/network/arp_table.rs index 2771eab..ba337d0 100644 --- a/kernel/src/network/arp_table.rs +++ b/kernel/src/network/arp_table.rs @@ -1,10 +1,26 @@ +use crate::task::timeout::estimate_epoch; + /// An entry in the ARP table -/// todo: This prob shouldn't get its own file pub struct ArpEntry { /// The mac address of the ARP entry pub mac: u64, /// The IP address of the ARP entry pub ip: u32, - /// How many timer interrupts before this entry expires - pub expires: u16, + /// How many timer interrupts when this entry expires + pub expiration_epoch: u64, } + +impl ArpEntry { + /// Create an arp entry + /// - mac: the mac address of the entry + /// - ip: ip address of the entry + /// - expires_in: calculating in how many epochs (timer interrupts, roughly 1/18 of a second) + pub fn new(mac: u64, ip: u32, expires_in: u16) -> Self { + Self { mac, ip, expiration_epoch: estimate_epoch() + expires_in as u64 } + } + + /// Return true if the arp entry is expired + pub fn try_expire(&self) -> bool { + estimate_epoch() > self.expiration_epoch + } +} \ No newline at end of file diff --git a/kernel/src/network/network_query.rs b/kernel/src/network/network_query.rs index 1204819..f6e9d3e 100644 --- a/kernel/src/network/network_query.rs +++ b/kernel/src/network/network_query.rs @@ -24,10 +24,14 @@ impl NetworkQuery { let rtl_dev_info = rtl_dev_info_locked.get_mut().unwrap(); // iterate through entries in the arp table - for entry in rtl_dev_info.arp_table.iter() { - // todo: check for expired arps + for (index, entry) in rtl_dev_info.arp_table.iter().enumerate() { // if entry matches, we can return from the cache if entry.ip == ip { + // If entry is expired, remove and break + if entry.try_expire() { + rtl_dev_info.arp_table.remove(index); + break; + } return Some(entry.mac); } } diff --git a/kernel/src/network/processing.rs b/kernel/src/network/processing.rs index 9751f7d..13876f5 100644 --- a/kernel/src/network/processing.rs +++ b/kernel/src/network/processing.rs @@ -17,7 +17,7 @@ use crate::{ tcp::TCPPacket, tcp_session::{SessionAction, TCPSession}, }, - print, println, + println, }; use core::{ @@ -110,8 +110,7 @@ pub async fn init_packet_processing() { (|| { match amount_parsed_and_pkt.1 { PacketData::ARP(arp) => { - // todo: also check for broadcast, and expire from arp table - if arp.dest_ip.val() == rtl_dev_info.my_ip_address.unwrap_or(0) { + if arp.dest_ip.val() == rtl_dev_info.my_ip_address.unwrap_or(0) || arp.dest_ip.val() == BROADCAST_ADDR { // WE GOT A REQUEST, so create a response let eth_layer = EthernetPacket::gen(arp.src_mac.val(), rtl_dev_info.mac_address.unwrap(), EthType::Arp); let ip_address = rtl_dev_info.my_ip_address.unwrap_or(BROADCAST_ADDR); @@ -120,12 +119,10 @@ pub async fn init_packet_processing() { // and send it rtl_dev_info.send_packet(&arp_pkt); } else { - // WE GOT A RESPONSE, so save it in the arp table - rtl_dev_info.arp_table.push(ArpEntry { - mac: arp.src_mac.val(), - ip: arp.src_ip.val(), - expires: 0, - }); + // WE GOT A RESPONSE, saving into the arp table with an expiration of an hour + rtl_dev_info.arp_table.push(ArpEntry::new(arp.src_mac.val(), arp.src_ip.val(), 64800)); + // ? Note: there is an attack vector where I send infinite ARP responses and fill up our arp table + // ? For now, I am content with not solving this issue... // If there was some process listening on the ARP "port" -> then we have to upstream the packet if rtl_dev_info.open_ports.contains(&ARP_PORT) { // if we are listening on the port, try to insert initialize it into the map diff --git a/kernel/src/task/timeout.rs b/kernel/src/task/timeout.rs index 05c245c..a3bfd0d 100644 --- a/kernel/src/task/timeout.rs +++ b/kernel/src/task/timeout.rs @@ -3,9 +3,13 @@ use core::{cell::RefCell, sync::atomic::AtomicU64, task::Waker}; use lazy_static::lazy_static; use x86_64::instructions::interrupts; -/// An internal counter for how many timer interrupts occured +/// An internal counter for how many timer interrupts occurred static mut INTERRUPT_COUNTER: u64 = 0; +pub fn estimate_epoch() -> u64 { + unsafe { INTERRUPT_COUNTER } +} + /// An id for a timeout #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] pub struct TimeoutID(u64); From c8b5c6eeb898d663f781d2853ce6f25c9c20fd6f Mon Sep 17 00:00:00 2001 From: EthanLavi <89608076+EthanLavi@users.noreply.github.com> Date: Mon, 4 Dec 2023 13:38:57 -0500 Subject: [PATCH 25/36] Reliable writing, rng --- kernel/src/crypto/mod.rs | 1 + kernel/src/crypto/rng.rs | 20 +++++++++++++ kernel/src/lib.rs | 1 + kernel/src/network/bytefield.rs | 1 - kernel/src/network/dhcp.rs | 29 ++---------------- kernel/src/network/icmp.rs | 1 - kernel/src/network/ip.rs | 32 ++++---------------- kernel/src/network/rtl8139.rs | 15 +++++----- kernel/src/network/socket.rs | 50 ++++++++++++++++++++++--------- kernel/src/network/tcp_session.rs | 12 ++++---- kernel/src/task/tcp_echo.rs | 19 ++++++++---- kernel/src/task/udp_echo.rs | 13 +++++++- 12 files changed, 105 insertions(+), 89 deletions(-) create mode 100644 kernel/src/crypto/mod.rs create mode 100644 kernel/src/crypto/rng.rs delete mode 100644 kernel/src/network/icmp.rs diff --git a/kernel/src/crypto/mod.rs b/kernel/src/crypto/mod.rs new file mode 100644 index 0000000..61ee0f4 --- /dev/null +++ b/kernel/src/crypto/mod.rs @@ -0,0 +1 @@ +pub mod rng; diff --git a/kernel/src/crypto/rng.rs b/kernel/src/crypto/rng.rs new file mode 100644 index 0000000..c43da05 --- /dev/null +++ b/kernel/src/crypto/rng.rs @@ -0,0 +1,20 @@ +use core::sync::atomic::{AtomicU64, Ordering}; + +static SEED: AtomicU64 = AtomicU64::new(55304); + +/// Generate a new random number +/// Is global but thread safe +pub fn get_next_random_num() -> u32 { + // https://en.wikipedia.org/wiki/Linear_congruential_generator + let a = 1103515245; + let m = 0b10000000000000000000000000000000; // 2 ^ 31 + let c = 12345; + loop { + let old: u64 = SEED.load(Ordering::Relaxed); + let next: u64 = (a * old + c) % m; + let res = SEED.compare_exchange_weak(old, next, Ordering::Relaxed, Ordering::Relaxed); + if res.is_ok() { + return next as u32; + } + } +} \ No newline at end of file diff --git a/kernel/src/lib.rs b/kernel/src/lib.rs index bbf2f6a..46442f0 100644 --- a/kernel/src/lib.rs +++ b/kernel/src/lib.rs @@ -15,6 +15,7 @@ pub mod gdt; pub mod interrupts; pub mod memory; pub mod network; +pub mod crypto; pub mod prelude; pub mod process; pub mod serial; diff --git a/kernel/src/network/bytefield.rs b/kernel/src/network/bytefield.rs index 2f53cb9..cdfb2ca 100644 --- a/kernel/src/network/bytefield.rs +++ b/kernel/src/network/bytefield.rs @@ -4,7 +4,6 @@ use core::{ops::{Index, IndexMut}, fmt}; // --> therefore serializing will create network byte order (big endian) Bytefields // --> AND deserializing will create host byte order (small endian) Bytefields // (That's why val() doesn't swap the byte order!) -// todo: refactor the api to track the state of the byte order? (would this work?) /// A structure to represent data as an array of bytes /// Will store in both host and network byte-order. It is dependent on it's construction diff --git a/kernel/src/network/dhcp.rs b/kernel/src/network/dhcp.rs index 18549e7..96a5ea9 100644 --- a/kernel/src/network/dhcp.rs +++ b/kernel/src/network/dhcp.rs @@ -1,5 +1,7 @@ use alloc::vec; +use crate::crypto::rng::get_next_random_num; + use super::{ bytefield::{Bytefield128, Bytefield16, Bytefield32, Bytefield48, Bytefield8}, ethernet::EthernetPacket, @@ -8,26 +10,6 @@ use super::{ udp::UDPPacket, }; -/// A wrapper for thread-safe id generation -struct WrappedU32 { - data: u32, -} - -impl WrappedU32 { - /// Get the value - pub fn get(&self) -> u32 { - self.data - } - /// Set the value - pub fn set(&mut self, data: u32) { - self.data = data; - } -} - -// todo: Extract this generator (and refactor to a lockless approach) to the crypto folder -/// Generator for random IDs -static mut ID_GEN: spin::Mutex = spin::Mutex::new(WrappedU32 { data: 0 }); - /// A DHCP packet, implements Layer (usually 300 bytes or more) #[derive(Debug)] pub struct DHCPPacket { @@ -95,12 +77,7 @@ impl DHCPPacket { /// - mac_address: is the machine's MAC address pub fn gen(udp_layer: UDPPacket, ip_address: Option, mac_address: u64) -> Self { // Generate a unique identifier for the client - let identification = unsafe { - let mut id_gen = ID_GEN.lock(); - let id_gen_old = id_gen.get(); - id_gen.set((id_gen_old + 1) % 0xFFFF); - Bytefield32::new(id_gen.get()) - }; + let identification = Bytefield32::new(get_next_random_num()); // Get the mac address and place it into the client_hardware_address field let mac = Bytefield48::new(mac_address); let mut client_hardware_address = Bytefield128::new(0); diff --git a/kernel/src/network/icmp.rs b/kernel/src/network/icmp.rs deleted file mode 100644 index e973b82..0000000 --- a/kernel/src/network/icmp.rs +++ /dev/null @@ -1 +0,0 @@ -// TODO: \ No newline at end of file diff --git a/kernel/src/network/ip.rs b/kernel/src/network/ip.rs index c75fe4a..76ed4ef 100644 --- a/kernel/src/network/ip.rs +++ b/kernel/src/network/ip.rs @@ -1,6 +1,8 @@ use alloc::vec; use alloc::vec::Vec; +use crate::crypto::rng::get_next_random_num; + use super::{ bytefield::{Bytefield16, Bytefield32, Bytefield8}, ethernet::EthernetPacket, @@ -30,26 +32,6 @@ impl Protocol { } } -// todo: Replace for AtomicU16? -/// ID generator struct (a wrapped u16) -struct WrappedU16 { - data: u16, -} - -impl WrappedU16 { - /// Get the u16 - pub fn get(&self) -> u16 { - self.data - } - /// Set the u16 - pub fn set(&mut self, data: u16) { - self.data = data; - } -} - -/// Atomic u16 for id generation -static mut ID_GEN: spin::Mutex = spin::Mutex::new(WrappedU16 { data: 0 }); - /// A IP packet, implements Layer and HasChecksum (20 bytes) #[derive(Debug, Clone)] pub struct IPPacket { @@ -103,19 +85,15 @@ impl IPPacket { /// - dest_ip: the destination's IP address pub fn gen(eth_layer: EthernetPacket, data_length: u16, protocol: Protocol, src_ip: u32, dest_ip: u32) -> Self { // Generate a unique ID for the packet - let identification = unsafe { - let mut id_gen = ID_GEN.lock(); - let id_gen_old = id_gen.get(); - id_gen.set((id_gen_old + 1) % 0xFFFF); - Bytefield16::new(id_gen.get()) - }; + let identification = (get_next_random_num() % u16::MAX as u32) as u16; + // Construct the packet IPPacket { eth: eth_layer, version_hlen: 0x45, type_of_service: 0x0, total_length: Bytefield16::new(data_length + 20), // adding data length and size of IP packet - identification, + identification: Bytefield16::new(identification), flags_fragment_offset: Bytefield16::new(0), ttl: 120, // 120 - our packet needs to make it there protocol, diff --git a/kernel/src/network/rtl8139.rs b/kernel/src/network/rtl8139.rs index 3775595..22def0f 100644 --- a/kernel/src/network/rtl8139.rs +++ b/kernel/src/network/rtl8139.rs @@ -130,7 +130,6 @@ pub extern "x86-interrupt" fn network_handle(_stack_frame: InterruptStackFrame) } } -// todo: refactor to be a loop, this function needs to process >1 packet fn recv_packet(rtl_dev_info: &RTL8139) { if rtl_dev_info.recv_buffer.is_none() || rtl_dev_info.physical_mem_offset.is_none() { panic!("RTL8139 is not initialized properly"); @@ -140,7 +139,6 @@ fn recv_packet(rtl_dev_info: &RTL8139) { let mut cmd_port = Port::::new(cmd_reg); while unsafe { cmd_port.read() } & CMD_REG_BUFE == 0x0 { // Receive a packet by reading the buffer - // ? Reading the buffer is naturally unsafe? Is there a better way? let virtual_buffer_recv: VirtAddr = VirtAddr::new(rtl_dev_info.recv_buffer.unwrap().as_u64() + rtl_dev_info.physical_mem_offset.unwrap()); let vb_recv_start: *const u8 = virtual_buffer_recv.as_ptr(); @@ -152,10 +150,12 @@ fn recv_packet(rtl_dev_info: &RTL8139) { if header & 0x01 != 0 && header & 0x02 == 0 && header & 0x04 == 0 && header & 0x20 == 0 { let length_buf = rx_buffer.trim(2); // get the next two bytes let length = (length_buf[0] as u16) | (length_buf[1] as u16) << 8; - let packet = rx_buffer.trim((length - 4) as usize); - // ? throw out the crc... we don't need to check it... - rx_buffer.shift_amount(4); - add_pkt_data(packet); + if length != 0 { + let packet = rx_buffer.trim((length - 4) as usize); + // ? throw out the crc... we don't need to check it... + rx_buffer.shift_amount(4); + add_pkt_data(packet); + } // after receiving the packet, update CAPR and RECV_POS // increment recv_pos unsafe { @@ -170,13 +170,12 @@ fn recv_packet(rtl_dev_info: &RTL8139) { RECV_POS = ((RECV_POS + 4) & RX_READ_PTR_MASK) % RX_BUFFER_SIZE; } let mut capr = Port::::new((rtl_dev_info.config.io_base.unwrap() + CAPR) as u16); - // println!("[RECV_POS] {}", unsafe { RECV_POS }); unsafe { capr.write(RECV_POS - 0x10) }; } else { unsafe { RECV_POS = (RECV_POS + 2) % RX_BUFFER_SIZE; } - // break; + break; } } } diff --git a/kernel/src/network/socket.rs b/kernel/src/network/socket.rs index ab01f59..a8eb88f 100644 --- a/kernel/src/network/socket.rs +++ b/kernel/src/network/socket.rs @@ -1,4 +1,4 @@ -use core::cmp::max; +use core::cmp::{max, min}; use alloc::vec; use alloc::vec::Vec; @@ -115,17 +115,22 @@ impl Socket { /// Will listen for new connections and create new sessions once something arrives /// - **UDP can only listen for one connection (and thus will return none)** /// - **UDP must still call listen to transition into the correct state** - pub async fn listen(&mut self) -> Option { + /// - If we get "BadSocketState", we weren't a listening socket + /// - We might also get a timeout... + pub async fn listen(&mut self) -> Result, NetworkErrors> { if self.socket_state != SocketState::Listening { - return None; // todo: this is failing silently + return Err(NetworkErrors::BadSocketState); } loop { // Loop on packets from the raw socket if let Some(pkt_or_err) = self.raw_socket.next().await { // If we get an error from the socket, return None - if pkt_or_err.is_err() { - // todo: this is failing silently - return None; + if let Err(err) = pkt_or_err { + if err == NetworkErrors::Timeout { + // Ignore timeouts, loop again + continue; + } + return Err(err); } // Unwrap the packet let pkt = pkt_or_err.unwrap(); @@ -152,8 +157,8 @@ impl Socket { // release the driver drop(rtl_dev_info_guard); enable_network_interrupts(); - // and return None - return None; + // and return itself + return Ok(None); } else if pkt.get_type() == LayerType::Tcp && self.socket_type == SocketType::TCP { // If we have a match (TCP <--> TCP) // Then we unwrap the TCP packet @@ -168,7 +173,7 @@ impl Socket { println!("[INFO] Spawned new TCP session"); // And return a new socket object to be the ready socket // the current socket never transitions out of listening - return Some(Socket { + return Ok(Some(Socket { socket_type: SocketType::TCP, raw_socket, socket_state: SocketState::Ready, @@ -180,7 +185,7 @@ impl Socket { src_mac: self.src_mac, wait_timeout: self.wait_timeout, session_key, - }); + })); } } } @@ -379,7 +384,18 @@ impl Socket { } } + pub async fn reliable_write(&mut self, data: &mut Vec) -> Option { + while !data.is_empty() { + // Write until everything gets written + if let Err(err) = self.write(data).await { + return Some(err); + } + } + None + } + /// Write data to the socket + /// Is unreliable, will only write the pub async fn write(&mut self, data: &mut Vec) -> Result { // Check socket state if self.socket_state != SocketState::Ready { @@ -403,9 +419,11 @@ impl Socket { let eth_layer = EthernetPacket::gen(self.src_mac, self.dest_mac, ethernet::EthType::IPv4); let udp_size = UDPPacket::packet_size() + data.len() as u16; let ip_layer = IPPacket::gen(eth_layer, udp_size, Protocol::Udp, self.src_address, self.dest_ip); - let data_len = data.len(); + let data_len = min(data.len(), 60000); let mut udp_layer = UDPPacket::gen(ip_layer, self.src_port, self.dest_port, data_len as u16); - udp_layer.data = data.to_vec(); // todo: split the data and return the amount actually written! + udp_layer.data = data[..data_len].to_vec(); + let mut index = 0; + data.retain(|_| {index += 1; index > data_len}); udp_layer.ip.calculate_checksum(); udp_layer.calculate_checksum(); @@ -432,8 +450,12 @@ impl Socket { // Set up to receive ack by processing the send on the data let mut tcp_session_guard = NET_INFO.config.lock(); let tcp_session = tcp_session_guard.tcp_sessions.get_mut(&self.session_key).unwrap(); + let data_len = min(data.len(), tcp_session.window_size as usize); // Message pkt is our present to send to our server - let message_pkt = tcp_session.process_send(data); + let message_pkt = tcp_session.process_send(&data[..data_len]); + // Trim the array to remove the sent data + let mut index = 0; + data.retain(|_| {index += 1; index > data_len}); // Do error checking if let Err(err) = message_pkt { return Err(err); @@ -478,6 +500,6 @@ impl Socket { } } // Return ok - Ok(data.len() as u16) + Ok(data_len as u16) } } diff --git a/kernel/src/network/tcp_session.rs b/kernel/src/network/tcp_session.rs index 25c3231..06ddd76 100644 --- a/kernel/src/network/tcp_session.rs +++ b/kernel/src/network/tcp_session.rs @@ -1,6 +1,6 @@ use alloc::vec::Vec; -use crate::println; +use crate::{println, crypto::rng::get_next_random_num}; use super::{ bytefield::{Bytefield16, Bytefield32}, @@ -72,11 +72,11 @@ impl TCPSession { /// Generate a new tcp session with a template and the three fields required for a tcp session key /// - dest_ip + dest_port + src_port = session_key pub fn new(session_template: TCPPacket, dest_ip: u32, dest_port: u16, src_port: u16) -> Self { + let rng_seq_num = get_next_random_num(); TCPSession { session_template, - // todo: integrate with randomness to init an initial sequence number - sent_data_amount: 55304, - sent_data_acked: 55304, + sent_data_amount: rng_seq_num, + sent_data_acked: rng_seq_num, recv_data_amount: 0, dest_ip, dest_port, @@ -137,7 +137,7 @@ impl TCPSession { } /// A function for generating a packet to sent with data - pub fn process_send(&mut self, data: &Vec) -> Result { + pub fn process_send(&mut self, data: &[u8]) -> Result { // Check session state if self.session_state != TCPSessionState::Established { return Err(NetworkErrors::BadSocketState); @@ -146,7 +146,7 @@ impl TCPSession { let mut tcp_pkt = self.session_template.clone(); // Set flags and data tcp_pkt.turn_on_flags(TCP_ACK | TCP_PSH); - tcp_pkt.data = data.to_vec(); // todo: split the data and return the amount actually written! + tcp_pkt.data = data.to_vec(); // add the data size tcp_pkt.ip.total_length = diff --git a/kernel/src/task/tcp_echo.rs b/kernel/src/task/tcp_echo.rs index f8a533b..18141f6 100644 --- a/kernel/src/task/tcp_echo.rs +++ b/kernel/src/task/tcp_echo.rs @@ -19,9 +19,15 @@ pub async fn tcp_echo_server() { // Allow 10 sockets to form (we don't have threads so only listening to one at a time) for _ in 0..10 { // Listen for a new session socket to form - let mut socket = socket_gen.listen().await.unwrap(); + let socket_or_err = socket_gen.listen().await; + if socket_or_err.is_err() { + let err = unsafe { socket_or_err.unwrap_err_unchecked() }; + println!("[ERR] Found error {:?}", err); + break; + } + let mut socket = socket_or_err.unwrap().unwrap(); loop { - // Continously read from the socket + // Continuously read from the socket let data_or_err = socket.read(0).await; if let Ok(mut data) = data_or_err { if data.is_empty() { @@ -36,15 +42,18 @@ pub async fn tcp_echo_server() { if message.to_lowercase().trim() == "quit" || message.to_lowercase().trim() == "exit" { // If their message is quit or exit, we close the connection let mut exit_message = "Closing socket...\n".as_bytes().to_vec(); - let _ = socket.write(&mut exit_message).await; + let is_err = socket.reliable_write(&mut exit_message).await; + if is_err.is_some() { + println!("[ERR] Writing error {:?}", is_err.unwrap()); + } socket.close().await; println!("Closed socket"); break; } } // Echo back the data from the socket - let res_or_err = socket.write(&mut data).await; - if let Err(err) = res_or_err { + let res_or_err = socket.reliable_write(&mut data).await; + if let Some(err) = res_or_err { // Writing - print error, close the socket, and break from the loop println!("[ERR] (Writing): {:?}", err); socket.close().await; diff --git a/kernel/src/task/udp_echo.rs b/kernel/src/task/udp_echo.rs index 970df03..64ea025 100644 --- a/kernel/src/task/udp_echo.rs +++ b/kernel/src/task/udp_echo.rs @@ -18,7 +18,18 @@ pub async fn udp_echo_server() { } let mut socket = socket_or_err.unwrap(); // Listen for a single connection - socket.listen().await; + let status = socket.listen().await; + if status.is_err() { + // Check for an error when listening for a socket + let err = unsafe { status.unwrap_err_unchecked() }; + println!("[ERR] Found {:?} error when listening for a UDP socket", err); + return; + } + if status.unwrap().is_some() { + // Ensure we have None in UDP listen + println!("[ERR] Found unknown error when listening for a UDP socket"); + return; + } loop { // Loop trying to read from the socket let data_or_err = socket.read(0).await; From 158e299d883c50f23c037d2e526c0e9d6bcf44ce Mon Sep 17 00:00:00 2001 From: EthanLavi <89608076+EthanLavi@users.noreply.github.com> Date: Tue, 5 Dec 2023 22:38:40 -0500 Subject: [PATCH 26/36] More fixes --- kernel/src/network/constants.rs | 7 +++-- kernel/src/network/dhcp.rs | 4 +-- kernel/src/network/layer.rs | 43 +++++++++++++++-------------- kernel/src/network/network_query.rs | 6 ++-- kernel/src/network/processing.rs | 37 ++++++++++++++++--------- kernel/src/network/raw_array.rs | 29 +++++++++++-------- kernel/src/network/rtl8139.rs | 33 ++++++++++++---------- kernel/src/network/socket.rs | 4 +-- kernel/src/network/tcp.rs | 2 -- kernel/src/network/tcp_session.rs | 9 ++---- kernel/src/network/udp.rs | 2 -- 11 files changed, 96 insertions(+), 80 deletions(-) diff --git a/kernel/src/network/constants.rs b/kernel/src/network/constants.rs index d148071..13a21b0 100644 --- a/kernel/src/network/constants.rs +++ b/kernel/src/network/constants.rs @@ -18,11 +18,12 @@ pub const TOK: u16 = 1 << 2; // TOK bit pub const ROK: u16 = 1 << 0; // ROK bit pub const TRANSMIT_REG: [u32; 4] = [0x20, 0x24, 0x28, 0x2C]; // 4 transmit registers pub const TRANSMIT_CMD: [u32; 4] = [0x10, 0x14, 0x18, 0x1C]; // 4 cmd registers -pub const INTERRUPT_MASK: u16 = 0x01 | 0x04 | 0x10 | 0x08 | 0x02; // interrupt mask +pub const INTERRUPT_MASK: u16 = (0x01 | 0x04 | 0x10 | 0x08 | 0x02) & !TOK; // interrupt mask. +// Note: we are removing the transmit ok interrupt, we don't do anything with it so we don't need it's overhead... pub const RX_BUFFER_SIZE: u16 = 8192; // how big the buffer is pub const CONFIG_1_REG: u16 = 0x52; // has LWAKE + LWPTN, will allow power on device if active high (0x00) pub const CMD_REG_RST: u8 = 0x10; // Reset, set to 1 to invoke S/W reset, held to 1 while resetting -pub const CMD_REG_RE: u8 = 0x08; // Reciever Enable, enables receiving +pub const CMD_REG_RE: u8 = 0x08; // Receiver Enable, enables receiving pub const CMD_REG_TE: u8 = 0x04; // Transmitter Enable, enables transmitting pub const CMD_REG_BUFE: u8 = 0x01; // Rx buffer is empty pub const CMD_REG: u32 = 0x37; // command register (1byte) @@ -35,7 +36,7 @@ pub const RX_BUF_REG: u16 = 0x44; // receive buffer config register pub const RX_BROADCAST: u32 = 0x08; // Accept broadcast packets sent to mac ff:ff:ff:ff:ff:ff pub const RX_MULTICAST: u32 = 0x04; // Accept multicast packets pub const RX_PHYSICAL_MATCH: u32 = 0x02; // Accept physical matches -pub const RX_PROMISCOUS: u32 = 0x01; // Accept all packets +pub const RX_PROMISCUOUS: u32 = 0x01; // Accept all packets // TCP Constants pub const TCP_FIN: u8 = 0x1; // TCP FIN flag (gracefully closing connection) diff --git a/kernel/src/network/dhcp.rs b/kernel/src/network/dhcp.rs index 96a5ea9..9741f83 100644 --- a/kernel/src/network/dhcp.rs +++ b/kernel/src/network/dhcp.rs @@ -66,7 +66,7 @@ impl DHCPPacket { client_hardware_address: Bytefield128::new(0), // 16 bytes sname: [0; 64], // 64 bytes file: [0; 128], // 128 bytes - options: [0; 64], // 64 bytes (can be more, todo!) + options: [0; 64], // 64 bytes (can be more but we'll ignore them) // 300 bytes total } } @@ -101,7 +101,7 @@ impl DHCPPacket { client_hardware_address, // 16 bytes sname: [0; 64], // 64 bytes file: [0; 128], // 128 bytes - options: [0; 64], // todo: 64 bytes (can be more) + options: [0; 64], // 64 bytes (can be more but we'll ignore them) // 300 bytes total } } diff --git a/kernel/src/network/layer.rs b/kernel/src/network/layer.rs index b064d28..28cd197 100644 --- a/kernel/src/network/layer.rs +++ b/kernel/src/network/layer.rs @@ -1,3 +1,4 @@ +use alloc::boxed::Box; use alloc::vec::Vec; use super::arp::ArpPacket; @@ -108,24 +109,24 @@ pub enum LayerType { Dhcp, /// TCPPacket Tcp, - /// An error occured + /// An error occurred Err, /// No more data (but not error) End, } -// todo: reduce size of enum /// Wrapper type to allow me to return a generic /// Is both a type (what is the kind) and a packet (something that implements Layer) #[derive(Debug)] +#[allow(clippy::upper_case_acronyms)] pub enum PacketData { - Eth(EthernetPacket), - IP(IPPacket), - ARP(ArpPacket), - UDP(UDPPacket), + ETH(Box), + IP(Box), + ARP(Box), + UDP(Box), ICMP(EmptyLayer), - DHCP(DHCPPacket), - TCP(TCPPacket), + DHCP(Box), + TCP(Box), ERR(EmptyLayer), UNDEF(EmptyLayer), } @@ -134,35 +135,35 @@ impl PacketData { /// Forcefully unwrap the packet as EthernetPacket pub fn unwrap_eth(self) -> EthernetPacket { match self { - PacketData::Eth(val) => val, + PacketData::ETH(val) => *val, _ => unreachable!("Mismatched type. Couldn't unwrap"), } } /// Forcefully unwrap the packet as IPPacket pub fn unwrap_ip(self) -> IPPacket { match self { - PacketData::IP(val) => val, + PacketData::IP(val) => *val, _ => unreachable!("Mismatched type. Couldn't unwrap"), } } /// Forcefully unwrap the packet as ARPPacket pub fn unwrap_arp(self) -> ArpPacket { match self { - PacketData::ARP(val) => val, + PacketData::ARP(val) => *val, _ => unreachable!("Mismatched type. Couldn't unwrap"), } } /// Forcefully unwrap the packet as UDPPacket pub fn unwrap_udp(self) -> UDPPacket { match self { - PacketData::UDP(val) => val, + PacketData::UDP(val) => *val, _ => unreachable!("Mismatched type. Couldn't unwrap"), } } /// Forcefully unwrap the packet as DHCPPacket pub fn unwrap_dhcp(self) -> DHCPPacket { match self { - PacketData::DHCP(val) => val, + PacketData::DHCP(val) => *val, _ => unreachable!("Mismatched type. Couldn't unwrap"), } } @@ -176,14 +177,14 @@ impl PacketData { /// Forcefully unwrap the packet as TCPPacket pub fn unwrap_tcp(self) -> TCPPacket { match self { - PacketData::TCP(val) => val, + PacketData::TCP(val) => *val, _ => unreachable!("Mismatched type. Couldn't unwrap"), } } /// Get the type of the PacketData by decomposing into LayerType pub fn get_type(&self) -> LayerType { match self { - PacketData::Eth(_) => LayerType::Eth, + PacketData::ETH(_) => LayerType::Eth, PacketData::IP(_) => LayerType::IP, PacketData::ARP(_) => LayerType::Arp, PacketData::UDP(_) => LayerType::Udp, @@ -212,7 +213,7 @@ pub fn full_parse(packet: &[u8]) -> (usize, PacketData) { // parse the data (starting from i) let (eth_layer, size, network_layer_type) = EthernetPacket::parse(last_layer_data, &packet[i..]); // save the ethernet packet, increment i, and set the next type to be what the Ethernet Packet dictates - last_layer = PacketData::Eth(eth_layer); + last_layer = PacketData::ETH(Box::new(eth_layer)); i += size; next_type = network_layer_type; } @@ -222,7 +223,7 @@ pub fn full_parse(packet: &[u8]) -> (usize, PacketData) { // parse the data (starting from i) let (ip_layer, size, transport_layer_type) = IPPacket::parse(last_layer_data, &packet[i..]); // save the IP packet, increment i, and set the next type to be what the IP Packet dictates - last_layer = PacketData::IP(ip_layer); + last_layer = PacketData::IP(Box::new(ip_layer)); i += size; next_type = transport_layer_type; } @@ -232,7 +233,7 @@ pub fn full_parse(packet: &[u8]) -> (usize, PacketData) { // parse the data (starting from i) let (arp_layer, size, transport_layer_type) = ArpPacket::parse(last_layer_data, &packet[i..]); // save the ARP packet, increment i, and set the next type to be what the ARP Packet dictates - last_layer = PacketData::ARP(arp_layer); + last_layer = PacketData::ARP(Box::new(arp_layer)); i += size; next_type = transport_layer_type; } @@ -242,7 +243,7 @@ pub fn full_parse(packet: &[u8]) -> (usize, PacketData) { // parse the data (starting from i) let (udp_layer, size, application_layer_type) = UDPPacket::parse(last_layer_data, &packet[i..]); // save the UDP packet, increment i, and set the next type to be what the UDP Packet dictates - last_layer = PacketData::UDP(udp_layer); + last_layer = PacketData::UDP(Box::new(udp_layer)); i += size; next_type = application_layer_type; } @@ -256,7 +257,7 @@ pub fn full_parse(packet: &[u8]) -> (usize, PacketData) { // parse the data (starting from i) let (dhcp_layer, size, empty_type) = DHCPPacket::parse(last_layer_data, &packet[i..]); // save the DHCP packet, increment i, and set the next type to be what the DHCP Packet dictates - last_layer = PacketData::DHCP(dhcp_layer); + last_layer = PacketData::DHCP(Box::new(dhcp_layer)); i += size; next_type = empty_type; } @@ -266,7 +267,7 @@ pub fn full_parse(packet: &[u8]) -> (usize, PacketData) { // parse the data (starting from i) let (tcp_layer, size, empty_type) = TCPPacket::parse(last_layer_data, &packet[i..]); // save the TCP packet, increment i, and set the next type to be what the TCP Packet dictates - last_layer = PacketData::TCP(tcp_layer); + last_layer = PacketData::TCP(Box::new(tcp_layer)); i += size; next_type = empty_type; } diff --git a/kernel/src/network/network_query.rs b/kernel/src/network/network_query.rs index f6e9d3e..76c5b75 100644 --- a/kernel/src/network/network_query.rs +++ b/kernel/src/network/network_query.rs @@ -9,8 +9,6 @@ use super::{ rtl8139::{disable_network_interrupts, enable_network_interrupts, NET_INFO}, }; -// todo: move DHCP query here? - /// A module for querying things with the network stack pub struct NetworkQuery {} @@ -61,7 +59,9 @@ impl NetworkQuery { continue; } let arp_pkt = pkt_data.unwrap_arp(); - // todo: add to the ARP table + if arp_pkt.src_ip.val() != ip { + continue; + } // Once we've unwrapped the packet, we can close the socket and return the sender mac socket.close(); return Some(arp_pkt.src_mac.val()); diff --git a/kernel/src/network/processing.rs b/kernel/src/network/processing.rs index 13876f5..5b08939 100644 --- a/kernel/src/network/processing.rs +++ b/kernel/src/network/processing.rs @@ -1,4 +1,4 @@ -use alloc::{collections::VecDeque, vec::Vec}; +use alloc::collections::VecDeque; use conquer_once::spin::OnceCell; use crossbeam_queue::ArrayQueue; use futures_util::{task::AtomicWaker, Stream, StreamExt}; @@ -29,8 +29,12 @@ use super::layer::full_parse; /// A waker for waking the pending packet stream static PROCESS_VEC_WAKER: AtomicWaker = AtomicWaker::new(); +pub struct PacketBuf { + pub buffer: [u8; 1500], + pub length: u16, +} /// An array queue for data to parse -static PENDING_DATA: OnceCell>> = OnceCell::uninit(); +static PENDING_DATA: OnceCell> = OnceCell::uninit(); /// A empty struct with behavior /// - the idea is its a queue for the interrupt handler to enqueue vectors to represent packets @@ -42,9 +46,9 @@ struct PendingProcessingStream { impl PendingProcessingStream { /// Create a new pending process stream fn new() -> Self { - // Initialize the pending data array queue with max size 100 + // Initialize the pending data array queue with max size 20 PENDING_DATA - .try_init_once(|| ArrayQueue::new(100)) + .try_init_once(|| ArrayQueue::new(20)) .expect("PendingProcessingStream::new should only be called once"); PendingProcessingStream { _private: () } } @@ -52,14 +56,18 @@ impl PendingProcessingStream { impl Stream for PendingProcessingStream { // Output tokens for the polling - type Item = Vec; + type Item = PacketBuf; /// Get the next vector of data - fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll>> { + fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + // Get the driver configuration + disable_network_interrupts(); // Get the queue let queue = PENDING_DATA.try_get().expect("not initialized"); // Get the next entry and register to be woken up let data = queue.pop(); + // Release the driver + enable_network_interrupts(); PROCESS_VEC_WAKER.register(cx.waker()); match data { // If we got some data -> return it @@ -71,7 +79,7 @@ impl Stream for PendingProcessingStream { } /// A internal function of the module to append to the queue of potential packets -pub(crate) fn add_pkt_data(data: Vec) { +pub(crate) fn add_pkt_data(data: PacketBuf) { // Try to get the array queue if let Ok(queue) = PENDING_DATA.try_get() { // And push data @@ -94,13 +102,17 @@ pub async fn init_packet_processing() { // Get the next piece of data the interrupt handler drew from the card while let Some(pkt_data) = raw_packets.next().await { // Parse it - let amount_parsed_and_pkt = full_parse(pkt_data.as_slice()); + let amount_parsed_and_pkt = full_parse(&pkt_data.buffer); if amount_parsed_and_pkt.1.get_type() == LayerType::Err || amount_parsed_and_pkt.1.get_type() == LayerType::Icmp { // Don't deal with unrecognized packets (will fail assert) continue; } - // Assert we had a proper amount of data processed without erroring out - assert!(amount_parsed_and_pkt.0 == pkt_data.len() || pkt_data.len() < 64); + if amount_parsed_and_pkt.0 != pkt_data.length as usize && (pkt_data.length as usize) >= 64 { + // Assert we had a proper amount of data processed without erroring out + println!("[ERR] amount parsed: {} --> true length: {}", amount_parsed_and_pkt.0, pkt_data.length); + assert!(amount_parsed_and_pkt.0 == pkt_data.length as usize || (pkt_data.length as usize) < 64); + } + // Get the driver configuration disable_network_interrupts(); let mut rtl_dev_info_guard = NET_INFO.lock(); @@ -226,11 +238,10 @@ pub async fn init_packet_processing() { let ack_pkt = tcp_session.process_recv(&tcp); if let Some(response) = ack_pkt.0 { // If we got a response packet to send back, send it - // If no ack is receeived, the host will send another transmission for us to respond to + // If no ack is received, the host will send another transmission for us to respond to rtl_dev_info.send_packet(&response.serialize()); } - // todo: Release TCP resources after 5 minutes if never closed - // todo: remove the + // todo: Release TCP resources after 5 minutes if never closed? // Interpret the action from the process_recv function if ack_pkt.1 != SessionAction::Drop { // if we are listening on the session, try to init the packet queue on that end diff --git a/kernel/src/network/raw_array.rs b/kernel/src/network/raw_array.rs index 353183a..269135d 100644 --- a/kernel/src/network/raw_array.rs +++ b/kernel/src/network/raw_array.rs @@ -1,8 +1,5 @@ -use alloc::vec; -use alloc::vec::Vec; - /// An array to represent the buffer of the RTL8139 -/// (will wrap it's array accessess as required by the data-sheet) +/// (will wrap it's array accesses as required by the data-sheet) pub struct WrappingRawArray { /// The starting address of the buffer start: *const u8, @@ -23,16 +20,28 @@ impl WrappingRawArray { self.pos = (self.pos + amount) % self.size; } + pub fn get_next_u8(&mut self) -> u8 { + // Move to the starting position + let mut tmp_start = unsafe { self.start.add(self.pos) }; + let res = unsafe { *tmp_start }; + self.shift_amount(1); + res + } + + pub fn get_next_u16(&mut self) -> u16 { + let first = self.get_next_u8(); + let second = self.get_next_u8(); + (first as u16) | (second as u16) << 8 + } + /// Move the array forward, "consuming" those values - pub fn trim(&mut self, amount: usize) -> Vec { - // Create a result vector - let mut res = vec![]; + pub fn trim(&mut self, res: &mut [u8; 1500], amount: usize) { // Move to the starting position let mut tmp_start = unsafe { self.start.add(self.pos) }; - for _ in 0..amount { + for i in 0..amount { // append the byte and move tmp_start forward unsafe { - res.push(*tmp_start); + res[i] = *tmp_start; tmp_start = tmp_start.add(1); } // also increment the position @@ -44,8 +53,6 @@ impl WrappingRawArray { tmp_start = self.start; } } - // return the result - res } } diff --git a/kernel/src/network/rtl8139.rs b/kernel/src/network/rtl8139.rs index 22def0f..0ff5898 100644 --- a/kernel/src/network/rtl8139.rs +++ b/kernel/src/network/rtl8139.rs @@ -18,13 +18,13 @@ use crate::interrupts::{IDT, PIC_1_OFFSET}; use crate::network::constants::{ CAPR, CMD_REG, CMD_REG_BUFE, CMD_REG_RE, CMD_REG_RST, CMD_REG_TE, RX_BROADCAST, RX_BUFFER_SIZE, RX_BUF_REG, RX_READ_PTR_MASK, - RX_START_REG, RX_MULTICAST, RX_PHYSICAL_MATCH, RX_PROMISCOUS, CONFIG_1_REG, + RX_START_REG, RX_MULTICAST, RX_PHYSICAL_MATCH, RX_PROMISCUOUS, CONFIG_1_REG, }; use crate::network::raw_array::WrappingRawArray; use super::constants::{IMR_REG, INTERRUPT_MASK, ISR_REG, ROK, RTL_DEV, RTL_VEND, TOK, TRANSMIT_CMD, TRANSMIT_REG}; use super::errors::NetworkErrors; -use super::processing::add_pkt_data; +use super::processing::{add_pkt_data, PacketBuf}; use super::tcp_session::TCPSession; use super::{ arp_table::ArpEntry, @@ -41,19 +41,19 @@ use crate::{ // ISR_ROK|ISR_TOK|ISR_RXOVW|ISR_TER|ISR_RER // ! URGENT: start checking other statuses for errors -// todo: why can I get "send" interrupts... what is the purpose +// todo: why can I get "send" interrupts... what is the purpose --> ISR_TOK is now disabled // todo: write comments for this file // FROM OS DEV // N.B. If you find your driver suddenly freezes and stops receiving interrupts and you're using kvm/qemu. Try the option -no-kvm-irqchip static mut TRANSMIT_IDX: u32 = 0; -static mut RECV_POS: u16 = 0; // todo this should be wrapped in a lock? +static mut RECV_POS: u16 = 0; static mut IO_BASE: usize = 0; // these static muts are set from None to Some and never really changed lazy_static! { // ! if a process tries to get this lock without disabling interrupts, we would deadlock - // ! The safertl unwraps the locking mechanism with a RAII enable and disable (Well it would if I could get it to work) + // ! The SafeRTL8139 unwraps the locking mechanism with a RAII enable and disable (Well it would if I could get it to work) pub static ref NET_INFO: SafeRTL8139 = { let devices = devices::scan_devices(); SafeRTL8139::new(spin::Mutex::new(RTL8139::new(devices)), spin::Mutex::new(NetworkConfig::new())) @@ -144,17 +144,19 @@ fn recv_packet(rtl_dev_info: &RTL8139) { let vb_recv_start: *const u8 = virtual_buffer_recv.as_ptr(); let mut rx_buffer = WrappingRawArray::new(vb_recv_start, RX_BUFFER_SIZE as usize); rx_buffer.shift_amount(unsafe { RECV_POS } as usize); - let header_buf = rx_buffer.trim(2); - let header = (header_buf[0] as u16) | (header_buf[1] as u16) << 8; + let header = rx_buffer.get_next_u16(); // Checking receive OK and no errors if header & 0x01 != 0 && header & 0x02 == 0 && header & 0x04 == 0 && header & 0x20 == 0 { - let length_buf = rx_buffer.trim(2); // get the next two bytes - let length = (length_buf[0] as u16) | (length_buf[1] as u16) << 8; + let length = rx_buffer.get_next_u16(); // get the next two bytes if length != 0 { - let packet = rx_buffer.trim((length - 4) as usize); + let mut packet: [u8; 1500] = [0; 1500]; + rx_buffer.trim(&mut packet, (length - 4) as usize); // ? throw out the crc... we don't need to check it... rx_buffer.shift_amount(4); - add_pkt_data(packet); + add_pkt_data(PacketBuf { + buffer: packet, + length: length - 4, + }); } // after receiving the packet, update CAPR and RECV_POS // increment recv_pos @@ -222,12 +224,13 @@ impl RTL8139 { } frames.push(f.unwrap()); } - // Ensure the sections are continous + // Ensure the sections are continuous let mut next_start = frames[0].start_address() + frames[0].size(); + #[allow(clippy::needless_range_loop)] // This "1 -> 6" is much more simple than what clippy is suggesting for i in 1..6 { - // if frame isn't continous (can be acceptable on boundary between send and recv buffer) + // if frame isn't continuous (can be acceptable on boundary between send and recv buffer) if (frames[i].start_address() != next_start) && i != 3 { - println!("[ERR] Frames aren't continous {}", i); + println!("[ERR] Frames aren't continuous {}", i); return false; } next_start = frames[i].start_address() + frames[i].size(); @@ -351,7 +354,7 @@ impl RTL8139 { let mut rcr = Port::::new(rcr_reg); unsafe { // No wrap, ring buffer, accept all packets - rcr.write(RX_PHYSICAL_MATCH | RX_MULTICAST | RX_BROADCAST | RX_PROMISCOUS); + rcr.write(RX_PHYSICAL_MATCH | RX_MULTICAST | RX_BROADCAST | RX_PROMISCUOUS); }; // Init receive buffer diff --git a/kernel/src/network/socket.rs b/kernel/src/network/socket.rs index a8eb88f..0aeab15 100644 --- a/kernel/src/network/socket.rs +++ b/kernel/src/network/socket.rs @@ -1,6 +1,6 @@ use core::cmp::{max, min}; -use alloc::vec; +use alloc::{vec, boxed::Box}; use alloc::vec::Vec; use futures_util::StreamExt; @@ -151,7 +151,7 @@ impl Socket { // re-enqueue the packet if let Some(vec) = rtl_dev_info.ports.get_mut(&(self.src_port as u64)) { - vec.push_front(Ok(PacketData::UDP(udp_pkt))); + vec.push_front(Ok(PacketData::UDP(Box::new(udp_pkt)))); } // release the driver diff --git a/kernel/src/network/tcp.rs b/kernel/src/network/tcp.rs index be02ca6..0052afd 100644 --- a/kernel/src/network/tcp.rs +++ b/kernel/src/network/tcp.rs @@ -1,8 +1,6 @@ use alloc::vec; use alloc::vec::Vec; -use crate::println; - use super::{ bytefield::{Bytefield16, Bytefield32}, ethernet::EthernetPacket, diff --git a/kernel/src/network/tcp_session.rs b/kernel/src/network/tcp_session.rs index 06ddd76..d208f47 100644 --- a/kernel/src/network/tcp_session.rs +++ b/kernel/src/network/tcp_session.rs @@ -1,12 +1,9 @@ -use alloc::vec::Vec; - use crate::{println, crypto::rng::get_next_random_num}; use super::{ bytefield::{Bytefield16, Bytefield32}, constants::{TCP_ACK, TCP_FIN, TCP_PSH, TCP_RST, TCP_SYN}, errors::NetworkErrors, - ethernet::EthernetPacket, ip::IPPacket, layer::{HasChecksum, Layer}, tcp::TCPPacket, @@ -164,7 +161,7 @@ impl TCPSession { Ok(tcp_pkt) } - /// A function for processing a recevied packet and generating a response/action + /// A function for processing a received packet and generating a response/action /// - will return the response and session action as a tuple pub fn process_recv(&mut self, request: &TCPPacket) -> (Option, SessionAction) { // Extract flag results @@ -174,12 +171,12 @@ impl TCPSession { let has_rst_flag = (request.get_flags() & TCP_RST) != 0; // Set initial action as push upstream let mut response_action = SessionAction::PushUpstream; - // todo: regularly update window size + // regularly update window size. Will be used to throttle the packets we send via socket.write() self.window_size = request.sliding_window.val(); let mut response = self.session_template.clone(); // Setting length to be 24 [sizeof(TCP-header) + sizeof(options)] + 20 [sizeof(IP-header)] response.ip.total_length = Bytefield16::new(TCPPacket::packet_size() + TCPPacket::options_size() + IPPacket::packet_size()); - // todo! check seq-num and ack-num before doing these state changes + // todo: check seq-num and ack-num before doing these state changes if has_fin_flag && self.session_state == TCPSessionState::Established { // Transition to closing state self.session_state = TCPSessionState::Closing; diff --git a/kernel/src/network/udp.rs b/kernel/src/network/udp.rs index 6947921..f0579d3 100644 --- a/kernel/src/network/udp.rs +++ b/kernel/src/network/udp.rs @@ -1,5 +1,3 @@ -use crate::println; - use super::{ bytefield::Bytefield16, ethernet::EthernetPacket, From 9a99fe9777337dcf7de236472cad23f62e543ca9 Mon Sep 17 00:00:00 2001 From: EthanLavi <89608076+EthanLavi@users.noreply.github.com> Date: Tue, 5 Dec 2023 22:58:32 -0500 Subject: [PATCH 27/36] Some code cleanup --- kernel/src/network/constants.rs | 2 +- kernel/src/network/raw_array.rs | 6 +++--- kernel/src/network/socket.rs | 4 ++++ kernel/src/network/udp.rs | 5 ++--- 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/kernel/src/network/constants.rs b/kernel/src/network/constants.rs index 13a21b0..0eae6e6 100644 --- a/kernel/src/network/constants.rs +++ b/kernel/src/network/constants.rs @@ -44,4 +44,4 @@ pub const TCP_SYN: u8 = 0x2; // TCP SYN flag (synchronizing to open connection) pub const TCP_RST: u8 = 0x4; // TCP RST flag (forcefully terminate connection) pub const TCP_PSH: u8 = 0x8; // TCP PSH flag (push data to application layer) pub const TCP_ACK: u8 = 0x10; // TCP ACK flag (acknowledge something -- multi-use) -pub const TCP_URG: u8 = 0x20; // TCP URG flag (urgent data) +// pub const TCP_URG: u8 = 0x20; // TCP URG flag (urgent data) diff --git a/kernel/src/network/raw_array.rs b/kernel/src/network/raw_array.rs index 269135d..dc2cddb 100644 --- a/kernel/src/network/raw_array.rs +++ b/kernel/src/network/raw_array.rs @@ -22,7 +22,7 @@ impl WrappingRawArray { pub fn get_next_u8(&mut self) -> u8 { // Move to the starting position - let mut tmp_start = unsafe { self.start.add(self.pos) }; + let tmp_start = unsafe { self.start.add(self.pos) }; let res = unsafe { *tmp_start }; self.shift_amount(1); res @@ -38,10 +38,10 @@ impl WrappingRawArray { pub fn trim(&mut self, res: &mut [u8; 1500], amount: usize) { // Move to the starting position let mut tmp_start = unsafe { self.start.add(self.pos) }; - for i in 0..amount { + for val in res.iter_mut().take(amount) { // append the byte and move tmp_start forward unsafe { - res[i] = *tmp_start; + *val = *tmp_start; tmp_start = tmp_start.add(1); } // also increment the position diff --git a/kernel/src/network/socket.rs b/kernel/src/network/socket.rs index 0aeab15..7e9fe7e 100644 --- a/kernel/src/network/socket.rs +++ b/kernel/src/network/socket.rs @@ -331,7 +331,9 @@ impl Socket { } /// Internal function for reading as a UDP socket + #[allow(clippy::question_mark)] async fn read_udp(&mut self, size: usize) -> Result, NetworkErrors> { + // todo: fix reading loop { // Spin until we get some packet if let Some(pkt_or_err) = self.raw_socket.next().await { @@ -353,6 +355,7 @@ impl Socket { } /// Internal function for reading as a TCP socket + #[allow(clippy::question_mark)] async fn read_tcp(&mut self, size: usize) -> Result, NetworkErrors> { loop { // Create a result vector @@ -440,6 +443,7 @@ impl Socket { } /// Internal function for writing tcp + #[allow(clippy::question_mark)] async fn write_tcp(&mut self, data: &mut Vec) -> Result { // Check states assert!(self.socket_type == SocketType::TCP); diff --git a/kernel/src/network/udp.rs b/kernel/src/network/udp.rs index f0579d3..383fe2c 100644 --- a/kernel/src/network/udp.rs +++ b/kernel/src/network/udp.rs @@ -44,7 +44,7 @@ impl UDPPacket { /// - dest_port: the destination port (open on the server) /// - length: the length of the body (how much data under it) pub fn gen(ip_layer: IPPacket, src_port: u16, dest_port: u16, length: u16) -> Self { - let mut udp = UDPPacket { + UDPPacket { ip: ip_layer, src_port: Bytefield16::new(src_port), dest_port: Bytefield16::new(dest_port), @@ -52,8 +52,7 @@ impl UDPPacket { length: Bytefield16::new(length + 8), checksum: Bytefield16::new(0), data: Vec::new(), - }; - udp + } } } From 5b76f67f29f2997e0e31822a2008ebd7059faf7e Mon Sep 17 00:00:00 2001 From: EthanLavi <89608076+EthanLavi@users.noreply.github.com> Date: Fri, 8 Dec 2023 01:15:31 -0500 Subject: [PATCH 28/36] Tests (1 of 2) --- README.md | 4 +- kernel/src/lib.rs | 4 +- kernel/src/main.rs | 24 +++- kernel/src/network/README.md | 10 +- kernel/src/network/TODO.md | 5 +- kernel/src/network/arp.rs | 82 ++++++++++- kernel/src/network/arp_table.rs | 41 +++++- kernel/src/network/bytefield.rs | 49 ++++++- kernel/src/network/command_register.rs | 101 ++++++++++++++ kernel/src/network/constants.rs | 4 +- kernel/src/network/devices.rs | 22 ++- kernel/src/network/dhcp.rs | 69 +++++++-- kernel/src/network/errors.rs | 2 +- kernel/src/network/ethernet.rs | 55 +++++++- kernel/src/network/init.rs | 31 ++--- kernel/src/network/ip.rs | 89 ++++++++++-- kernel/src/network/layer.rs | 50 +++++++ kernel/src/network/mod.rs | 1 + kernel/src/network/netsync.rs | 64 ++++++--- kernel/src/network/network_query.rs | 133 +++++++++++++++--- kernel/src/network/processing.rs | 185 ++++++++++++++++++------- kernel/src/network/raw_array.rs | 47 +++++++ kernel/src/network/raw_socket.rs | 99 ++++++++----- kernel/src/network/rtl8139.rs | 86 +++++++----- kernel/src/network/socket.rs | 100 ++++++------- kernel/src/network/tcp.rs | 45 +++--- kernel/src/network/tcp_session.rs | 11 +- kernel/src/network/test.rs | 73 ++++++++++ kernel/src/network/udp.rs | 96 +++++++++++-- kernel/src/task/executor.rs | 24 +++- kernel/src/task/mod.rs | 2 +- kernel/src/task/udp_echo.rs | 2 +- src/main.rs | 5 +- 33 files changed, 1288 insertions(+), 327 deletions(-) create mode 100644 kernel/src/network/test.rs diff --git a/README.md b/README.md index 9a4c37e..eb7141c 100644 --- a/README.md +++ b/README.md @@ -36,4 +36,6 @@ TODO: Make a python script for communicating WASM code with the OS... (and docum \[WIP\] -`cargo test` \ No newline at end of file +`cargo test` + +Tests are run before starting the application on sudo cargo run \ No newline at end of file diff --git a/kernel/src/lib.rs b/kernel/src/lib.rs index 46442f0..c1ffcfc 100644 --- a/kernel/src/lib.rs +++ b/kernel/src/lib.rs @@ -86,7 +86,7 @@ pub fn test_panic_handler(info: &PanicInfo) -> ! { hlt_loop(); } -/// Entry point for `cargo test` +// Entry point for `cargo test` #[cfg(test)] entry_point!(test_kernel_main); @@ -102,4 +102,4 @@ fn test_kernel_main(_boot_info: &'static mut BootInfo) -> ! { #[panic_handler] fn panic(info: &PanicInfo) -> ! { test_panic_handler(info) -} +} \ No newline at end of file diff --git a/kernel/src/main.rs b/kernel/src/main.rs index dcf8457..9a5433d 100644 --- a/kernel/src/main.rs +++ b/kernel/src/main.rs @@ -4,6 +4,7 @@ #![test_runner(kernel::test_runner)] #![reexport_test_harness_main = "test_main"] +use kernel::{QemuExitCode, serial_println, network::test::{test_async, test_sync}}; use bootloader_api::{ config::{BootloaderConfig, Mapping}, entry_point, BootInfo, @@ -13,12 +14,12 @@ use kernel::{ framebuffer, hlt_loop, network::{ init::{init_dhcp, init_process_packet_data}, - rtl8139::NET_INFO, + rtl8139::NET_INFO }, println, task::keyboard, task::{executor::Executor, Task}, - task::{tcp_echo, udp_echo}, + task::{tcp_echo, udp_echo}, exit_qemu, }; extern crate alloc; @@ -85,11 +86,26 @@ fn kernel_main(boot_info: &'static mut BootInfo) -> ! { #[cfg(test)] test_main(); + // Clear stdout and write test stuff + serial_println!("\n\n\n\n\n\n\n\n\n\n\n\n"); + serial_println!("-----------------"); + serial_println!("SOUP OS -- TESTS"); + // Start running sync tests, stopping if there is an error + let res = test_sync(); + if let Err(err) = res { + serial_println!("[ERR] {}", err); + exit_qemu(QemuExitCode::Failed); + } + let mut executor = Executor::new(); // Start the processing of pending packets init_process_packet_data(&mut executor); - executor.spawn(Task::new(do_init_dhcp())); // not entirely async, will finish before others are run - executor.wait(); // todo: fix wait + // not entirely async, must finish before others are run + let wait_for_init_dhcp = executor.spawn(Task::new(do_init_dhcp())); + executor.wait(wait_for_init_dhcp); + // start async tests + let wait_for_test = executor.spawn(Task::new(test_async())); + executor.wait(wait_for_test); executor.spawn(Task::new(keyboard::print_keypresses())); executor.spawn(Task::new(udp_echo::udp_echo_server())); executor.spawn(Task::new(tcp_echo::tcp_echo_server())); diff --git a/kernel/src/network/README.md b/kernel/src/network/README.md index 0f61a8e..ee55e55 100644 --- a/kernel/src/network/README.md +++ b/kernel/src/network/README.md @@ -42,7 +42,7 @@ To get a packet to the application layer, this is the pipeline it follows. rtl8139.rs --> processing.rs --> tcp_session.rs (if tcp packet) --> raw_socket.rs --> socket.rs -## TODOS +## TODO [x] PCI scanning for devices [x] RTL8139 Driver Code @@ -52,15 +52,11 @@ rtl8139.rs --> processing.rs --> tcp_session.rs (if tcp packet) --> raw_socket.r [x] Async IO [x] Timeouts [] Refactor so that all of networking is tested -[] Verify other parts of the packet [] Fix synchronization to be much cleaner -[] Clean up ugly stuff -[] search for todo and fix thoses -[] Everything in TODO.md ## Receiving packets -Firstly, in rtl8139.rs, we receive a device interrupt that a new packet has arrived (length + data). To keep the interrupt handler short, we copy the contents of the data to a buffer and wake the processor. This processor is in processing.rs and async-ly deserializes the buffer. After doing so, we determine which port it belongs to. See the discussion below about the design difference compared to the traditional notion of ports. (For tcp, we do a little more processing before sending it to the port by acknowledging the receipt of data, tcp_session.rs). Furthermore, once the port is deteremined, we can check the open ports to see if something truly needs the data. If so, we insert it into a port-specific buffer and wake the raw_socket (raw_socket.rs). The raw socket is a stream which is queried by a Socket (socket.rs) that owns it. This allows for async read. +Firstly, in rtl8139.rs, we receive a device interrupt that a new packet has arrived (length + data). To keep the interrupt handler short, we copy the contents of the data to a buffer and wake the processor. This processor is in processing.rs and async-ly deserializes the buffer. After doing so, we determine which port it belongs to. See the discussion below about the design difference compared to the traditional notion of ports. (For tcp, we do a little more processing before sending it to the port by acknowledging the receipt of data, tcp_session.rs). Furthermore, once the port is determined, we can check the open ports to see if something truly needs the data. If so, we insert it into a port-specific buffer and wake the raw_socket (raw_socket.rs). The raw socket is a stream which is queried by a Socket (socket.rs) that owns it. This allows for async read. ### u64 "ports" @@ -68,4 +64,4 @@ Traditionally, outward facing ports are from 0-65535, or a u16. But to handle so ### Timeouts -Because we are dealing with async in our socket api, when we put the raw_socket task to sleep - it won't wake up until a packet is received so that processing.rs can wake it. As a result, I introduced timeouts. Where the socket will either wake on receipt of a packet or by the handler of the timer interrupt. We register a timeout by saving a waker and how many epochs (iterations of the timer = ~1/18 of a second) into a minheap. Then in the handler, we can draw the minimum value from the heap and check if we've reached that epoch. If so, we wake the task and draw the next minimum. +Because we are dealing with async in our socket api, when we put the raw_socket task to sleep - it won't wake up until a packet is received so that processing.rs can wake it. As a result, I introduced timeouts. Where the socket will either wake on receipt of a packet or by the handler of the timer interrupt. We register a timeout by saving a waker and how many epochs (iterations of the timer = ~1/18 of a second) into a min-heap. Then in the handler, we can draw the minimum value from the heap and check if we've reached that epoch. If so, we wake the task and draw the next minimum. diff --git a/kernel/src/network/TODO.md b/kernel/src/network/TODO.md index b0e0df9..bdd0bd2 100644 --- a/kernel/src/network/TODO.md +++ b/kernel/src/network/TODO.md @@ -1,8 +1,7 @@ # TODOs * Refactor so that all of networking is tested (Fri) -* Fix synchronization to be much cleaner (Sat/Sun) -* Clean up ugly stuff -- final code review -* Search for todo and fix thoses (Sat) * Benchmarking (Sun) * Fix socket errors (Closing sockets... let's try different sources?) +* src before dest in gen functions +* TODOs \ No newline at end of file diff --git a/kernel/src/network/arp.rs b/kernel/src/network/arp.rs index ed17a0b..51b4106 100644 --- a/kernel/src/network/arp.rs +++ b/kernel/src/network/arp.rs @@ -1,10 +1,12 @@ +use crate::{check, network::layer::full_parse, serial_print, serial_println, mark_as_test, test_ok}; + use super::{ bytefield::{Bytefield16, Bytefield32, Bytefield48, Bytefield8}, ethernet::{EthType, EthernetPacket}, layer::{Layer, LayerType}, }; -use alloc::vec; use alloc::vec::Vec; +use alloc::{string::String, vec}; /// An arp packet, implements Layer (28 bytes) #[derive(Debug)] @@ -27,7 +29,7 @@ pub struct ArpPacket { pub src_ip: Bytefield32, /// The mac address of the receiver (0 if a request, this is the question field) pub dest_mac: Bytefield48, - /// The recepient IP, this is also part of the question if a request + /// The recipient IP, this is also part of the question if a request pub dest_ip: Bytefield32, } @@ -69,10 +71,15 @@ impl ArpPacket { operation: Bytefield16::new(if is_req { 1 } else { 2 }), // Request is 1, Reply is 2 src_mac, src_ip: Bytefield32::new(src_ip), - dest_mac, - dest_ip: Bytefield32::new(if is_req { 0 } else { dest_ip }), + dest_mac: if is_req { Bytefield48::new(0) } else { dest_mac }, + dest_ip: Bytefield32::new(dest_ip), } } + + /// If the arp is a response/reply. False if it's a request + pub fn is_response(&self) -> bool { + self.operation.val() == 2 + } } impl Layer for ArpPacket { @@ -83,8 +90,12 @@ impl Layer for ArpPacket { /// - eth_layer: a parsed Ethernet packet /// - bytevec: the data to parse, with trailing but it starts where the packet must begin fn parse(eth_layer: EthernetPacket, bytevec: &[u8]) -> (Self, usize, LayerType) { - let mut packet = ArpPacket::new(); // create an empty packet - + // create an empty packet + let mut packet = ArpPacket::new(); + // Check bytevec size + if bytevec.len() < Self::packet_size() as usize { + return (packet, 0, LayerType::Err) + } // Save ethernet packet and read 20 bytes let mut i = 0; packet.eth = eth_layer; @@ -126,3 +137,62 @@ impl Layer for ArpPacket { 28 } } + +pub fn test() -> Result<(), String> { + mark_as_test!("ARP Packet"); + + // Create an arp packet + let pkt: ArpPacket = ArpPacket::new(); + check!( + pkt.serialize().len() == ArpPacket::packet_size() as usize + EthernetPacket::packet_size() as usize, + "Check if arp packet + ethernet packet is correct size of serialized packet" + ); + + // Create an arp packet (checking request and reply conditions as well as src/dst ips) + let eth = EthernetPacket::gen(1, 2, EthType::Arp); + let arp_req = ArpPacket::gen(eth, 3, 4, true); + check!(arp_req.operation.swapped().val() == 1, "Arp packet is request"); + let eth2 = EthernetPacket::gen(1, 2, EthType::Arp); + let arp_res = ArpPacket::gen(eth2, 3, 4, false); + check!(arp_res.operation.swapped().val() == 2, "Arp packet is reply"); + check!(arp_res.src_ip.swapped().val() == 3, "Src ip is correct"); + check!(arp_res.dest_ip.swapped().val() == 4, "Dest ip is correct"); + + // Check serialization and deserialization property + let serialized = arp_res.serialize(); + let pkt = full_parse(&serialized); + check!(pkt.1.get_type() == LayerType::Arp, "Check it deserialized to an arp packet"); + check!(pkt.0 == serialized.len(), "Check it deserialized everything we serialized"); + let arp_pkt = pkt.1.unwrap_arp(); + // We swap endianness because the arp_res is in network order and the deserialized version is in host order + check!( + arp_pkt.hardware_type.swapped() == arp_res.hardware_type, + "Comparing packet hardware_type" + ); + check!( + arp_pkt.protocol_type.swapped() == arp_res.protocol_type, + "Comparing packet protocol_type" + ); + check!( + arp_pkt.hardware_address_length == arp_res.hardware_address_length, + "Comparing packet hardware_address_length" + ); + check!( + arp_pkt.protocol_address_length == arp_res.protocol_address_length, + "Comparing packet protocol_address_length" + ); + check!(arp_pkt.operation.swapped() == arp_res.operation, "Comparing packet operation"); + check!(arp_pkt.src_mac.swapped() == arp_res.src_mac, "Comparing packet src_mac"); + check!(arp_pkt.src_ip.swapped() == arp_res.src_ip, "Comparing packet src_ip"); + check!(arp_pkt.dest_mac.swapped() == arp_res.dest_mac, "Comparing packet dest_mac"); + check!(arp_pkt.dest_ip.swapped() == arp_res.dest_ip, "Comparing packet dest_ip"); + + // Check full parse on an incomplete packet + let mut serialized = arp_res.serialize(); + // remove last element, making this packet invalid + serialized.pop(); + let (_, err_pkt) = full_parse(&serialized); + check!(err_pkt.get_type() == LayerType::Err, "Deserializing an arp packet without enough data is an error"); + + test_ok!(); +} diff --git a/kernel/src/network/arp_table.rs b/kernel/src/network/arp_table.rs index ba337d0..d3496cd 100644 --- a/kernel/src/network/arp_table.rs +++ b/kernel/src/network/arp_table.rs @@ -1,4 +1,7 @@ -use crate::task::timeout::estimate_epoch; +use alloc::string::String; +use x86_64::instructions::{hlt, interrupts::without_interrupts}; + +use crate::{check, mark_as_test, serial_print, serial_println, task::timeout::estimate_epoch, test_ok}; /// An entry in the ARP table pub struct ArpEntry { @@ -16,11 +19,41 @@ impl ArpEntry { /// - ip: ip address of the entry /// - expires_in: calculating in how many epochs (timer interrupts, roughly 1/18 of a second) pub fn new(mac: u64, ip: u32, expires_in: u16) -> Self { - Self { mac, ip, expiration_epoch: estimate_epoch() + expires_in as u64 } + Self { + mac, + ip, + expiration_epoch: estimate_epoch() + expires_in as u64, + } } /// Return true if the arp entry is expired pub fn try_expire(&self) -> bool { - estimate_epoch() > self.expiration_epoch + estimate_epoch() >= self.expiration_epoch } -} \ No newline at end of file +} + +pub fn test() -> Result<(), String> { + mark_as_test!("ARP Entry"); + + let did_expire = without_interrupts(|| { + // Check if an arp entry is by default not going to expire + let entry = ArpEntry::new(0, 0, 1); + entry.try_expire() + }); + check!(!did_expire, "Expiration before a timer interrupt (waiting on 0)"); + + // Then check if an arp entry will expire after 1 interrupt + let entry = ArpEntry::new(0, 0, 1); + // Wait an interrupt + hlt(); + check!(entry.try_expire(), "Expiration after a single timer interrupt (waiting on 1)"); + + // Create an arp entry that expires only after 2 + let entry2 = ArpEntry::new(0, 0, 2); + hlt(); + check!(!entry2.try_expire(), "No expiration after a single timer interrupt (waiting on 2)"); + hlt(); + check!(entry2.try_expire(), "Expiration after two timer interrupts (waiting on 2)"); + + test_ok!(); +} diff --git a/kernel/src/network/bytefield.rs b/kernel/src/network/bytefield.rs index cdfb2ca..8caf028 100644 --- a/kernel/src/network/bytefield.rs +++ b/kernel/src/network/bytefield.rs @@ -1,4 +1,12 @@ -use core::{ops::{Index, IndexMut}, fmt}; +use core::{ + fmt, + mem::size_of, + ops::{Index, IndexMut}, +}; + +use alloc::{string::String, vec}; + +use crate::{check, serial_print, serial_println, mark_as_test, test_ok}; // N.B.: Bytefields will swap the endianness of the values when created // --> therefore serializing will create network byte order (big endian) Bytefields @@ -7,14 +15,14 @@ use core::{ops::{Index, IndexMut}, fmt}; /// A structure to represent data as an array of bytes /// Will store in both host and network byte-order. It is dependent on it's construction -#[derive(Clone, Copy)] +#[derive(Clone, Copy, PartialEq, Eq)] pub struct Bytefield { pub data: [u8; N], } impl Bytefield { /// Return the bytefield with swapped endianness - pub fn swapped_endianness(self) -> Self { + pub fn swapped(self) -> Self { let mut data = self.data; data.reverse(); Self { data } @@ -104,3 +112,38 @@ bytefield_int!(Bytefield32, u32, 4); bytefield_int!(Bytefield48, u64, 6); bytefield_int!(Bytefield64, u64, 8); bytefield_int!(Bytefield128, u128, 16); + +pub fn test() -> Result<(), String> { + mark_as_test!("Bytefields"); + let bytefield32 = Bytefield32::new(0x01020304); + // Checking indexing into bytefield with inverted byte order + check!(bytefield32[0] == 1, "Indexing into bytefield is correct"); + check!(bytefield32[1] == 2, "Indexing into bytefield is correct"); + check!(bytefield32[2] == 3, "Indexing into bytefield is correct"); + check!(bytefield32[3] == 4, "Indexing into bytefield is correct"); + check!(bytefield32.val() == 0x04030201, "Val returns proper value"); + check!( + bytefield32.swapped().val() == 0x01020304, + "Val returns original value after swapping" + ); + + // Checking swapped endianness + check!(bytefield32.swapped() != bytefield32, "Swapped endianness works"); + check!(bytefield32.swapped().swapped() == bytefield32, "Swapped endianness works"); + + // Check all data-types + check!(Bytefield8::size() == size_of::(), "Correct data type for u8"); + check!(Bytefield16::size() == size_of::(), "Correct data type for u16"); + check!(Bytefield32::size() == size_of::(), "Correct data type for u32"); + check!(Bytefield64::size() == size_of::(), "Correct data type for u64"); + check!(Bytefield128::size() == size_of::(), "Correct data type for u128"); + check!(Bytefield48::size() == 48 / 8, "Correct data type for u48"); + + // Check read_inc function + let mut x = 0; + let vec = vec![0x00, 0x10, 0x02, 0x30, 0x04, 0x50]; + let num = Bytefield48::read_inc(&vec, &mut x); + check!(num.val() == 0x001002300450_u64, "Checking read_inc gets the correct value"); + check!(x == 6, "Checking read_inc increments by the correct value"); + test_ok!(); +} diff --git a/kernel/src/network/command_register.rs b/kernel/src/network/command_register.rs index 6de30dc..0e81dfd 100644 --- a/kernel/src/network/command_register.rs +++ b/kernel/src/network/command_register.rs @@ -1,3 +1,7 @@ +use alloc::string::String; + +use crate::{check, serial_print, serial_println, mark_as_test, test_ok}; + /// CommandRegister is an object to represent the command register for PCI devices /// See https://wiki.osdev.org/PCI#Command_Register #[derive(Debug, Clone)] @@ -137,3 +141,100 @@ impl CommandRegister { } // Bits 11-15 are reserved } + +pub fn test() -> Result<(), String> { + mark_as_test!("Command Register"); + + // Check bit manipulation of command register + let mut cr = CommandRegister::new(0); + check!(cr.data() == 0, ""); + check!(!cr.get_bus_master_bit(), ""); + check!(!cr.get_memory_space_bit(), ""); + check!(!cr.get_io_space_bit(), ""); + check!(!cr.get_serr_enable_bit(), ""); + check!(!cr.get_parity_err_res_bit(), ""); + check!(!cr.get_interrupt_disable_bit(), ""); + check!(!cr.get_special_cycles_bit(), ""); + check!(!cr.get_memory_write_invalidate_enable_bit(), ""); + check!(!cr.get_vga_palette_snoop_bit(), ""); + check!(!cr.get_fast_back_to_back_enable_bit(), ""); + + cr.set_bus_master_bit(true); + check!(cr.data() == 0b0000_0000_0000_0100, ""); + check!(cr.get_bus_master_bit(), ""); + cr.set_bus_master_bit(false); + + cr.set_interrupt_disable_bit(true); + check!(cr.data() == 0b0000_0100_0000_0000, ""); + check!(cr.get_interrupt_disable_bit(), ""); + cr.set_interrupt_disable_bit(false); + + cr.set_io_space_bit(true); + check!(cr.data() == 0b0000_0000_0000_0001, ""); + check!(cr.get_io_space_bit(), ""); + cr.set_io_space_bit(false); + + cr.set_memory_space_bit(true); + check!(cr.data() == 0b0000_0000_0000_0010, ""); + check!(cr.get_memory_space_bit(), ""); + cr.set_memory_space_bit(false); + + cr.set_parity_err_res_bit(true); + check!(cr.data() == 0b0000_0000_0100_0000, ""); + check!(cr.get_parity_err_res_bit(), ""); + cr.set_parity_err_res_bit(false); + + cr.set_serr_enable_bit(true); + check!(cr.data() == 0b0000_0001_0000_0000, ""); + check!(cr.get_serr_enable_bit(), ""); + cr.set_serr_enable_bit(false); + + cr = CommandRegister::new(0b0000_0000_0000_1000); + check!(!cr.get_bus_master_bit(), ""); + check!(!cr.get_memory_space_bit(), ""); + check!(!cr.get_io_space_bit(), ""); + check!(!cr.get_serr_enable_bit(), ""); + check!(!cr.get_parity_err_res_bit(), ""); + check!(!cr.get_interrupt_disable_bit(), ""); + check!(cr.get_special_cycles_bit(), ""); + check!(!cr.get_memory_write_invalidate_enable_bit(), ""); + check!(!cr.get_vga_palette_snoop_bit(), ""); + check!(!cr.get_fast_back_to_back_enable_bit(), ""); + + cr = CommandRegister::new(0b0000_0000_0001_0000); + check!(!cr.get_bus_master_bit(), ""); + check!(!cr.get_memory_space_bit(), ""); + check!(!cr.get_io_space_bit(), ""); + check!(!cr.get_serr_enable_bit(), ""); + check!(!cr.get_parity_err_res_bit(), ""); + check!(!cr.get_interrupt_disable_bit(), ""); + check!(!cr.get_special_cycles_bit(), ""); + check!(cr.get_memory_write_invalidate_enable_bit(), ""); + check!(!cr.get_vga_palette_snoop_bit(), ""); + check!(!cr.get_fast_back_to_back_enable_bit(), ""); + + cr = CommandRegister::new(0b0000_0000_0010_0000); + check!(!cr.get_bus_master_bit(), ""); + check!(!cr.get_memory_space_bit(), ""); + check!(!cr.get_io_space_bit(), ""); + check!(!cr.get_serr_enable_bit(), ""); + check!(!cr.get_parity_err_res_bit(), ""); + check!(!cr.get_interrupt_disable_bit(), ""); + check!(!cr.get_special_cycles_bit(), ""); + check!(!cr.get_memory_write_invalidate_enable_bit(), ""); + check!(cr.get_vga_palette_snoop_bit(), ""); + check!(!cr.get_fast_back_to_back_enable_bit(), ""); + + cr = CommandRegister::new(0b0000_0010_0000_0000); + check!(!cr.get_bus_master_bit(), ""); + check!(!cr.get_memory_space_bit(), ""); + check!(!cr.get_io_space_bit(), ""); + check!(!cr.get_serr_enable_bit(), ""); + check!(!cr.get_parity_err_res_bit(), ""); + check!(!cr.get_interrupt_disable_bit(), ""); + check!(!cr.get_special_cycles_bit(), ""); + check!(!cr.get_memory_write_invalidate_enable_bit(), ""); + check!(!cr.get_vga_palette_snoop_bit(), ""); + check!(cr.get_fast_back_to_back_enable_bit(), ""); + test_ok!(); +} diff --git a/kernel/src/network/constants.rs b/kernel/src/network/constants.rs index 0eae6e6..05784f3 100644 --- a/kernel/src/network/constants.rs +++ b/kernel/src/network/constants.rs @@ -7,8 +7,8 @@ pub const BROADCAST_ADDR: u32 = 0xFFFFFFFF; // broadcast IP address pub const BROADCAST_MAC: u64 = 0xFFFFFFFFFFFF; // broadcast MAC address // Common port numbers -pub const DHCP_CLIENT_PORT: u64 = 68; // client port for dhcp requests -pub const DHCP_SERVER_PORT: u64 = 67; // server port for dhcp requests +pub const DHCP_CLIENT_PORT: u16 = 68; // client port for dhcp requests +pub const DHCP_SERVER_PORT: u16 = 67; // server port for dhcp requests pub const ARP_PORT: u64 = u16::MAX as u64 + 2; // we are using ports above u16 to allow for extended our open "ports" map // RTL device specific info diff --git a/kernel/src/network/devices.rs b/kernel/src/network/devices.rs index f66a4d6..55af728 100644 --- a/kernel/src/network/devices.rs +++ b/kernel/src/network/devices.rs @@ -1,6 +1,8 @@ -use alloc::vec::Vec; +use alloc::{vec::Vec, string::String}; use x86_64::instructions::port::Port; +use crate::{serial_println, serial_print, check, mark_as_test, test_ok}; + use super::{ command_register::CommandRegister, constants::{PCI_CONFIG_ADDRESS, PCI_CONFIG_DATA}, @@ -174,7 +176,8 @@ fn pci_get_io_base(bus: u8, slot: u8) -> Option { // Must disable IO and memory space decode before accessing bar cr.set_memory_space_bit(false); cr.set_io_space_bit(false); - // todo: why doesn't this work? --> pci_set_cmd_reg(bus, slot, cr); + // todo: why doesn't this work? --> + // ! pci_set_cmd_reg(bus, slot, cr); for i in 0..5 { // Read the BAR let bar = pci_config_read_dword(bus, slot, 0, 0x10 + (i * 0x4)); @@ -264,3 +267,18 @@ pub fn scan_devices() -> Vec { // Return the results of the device scan results } + +pub fn test() -> Result<(), String> { + mark_as_test!("PCI-Devices"); + // check if the rtl8139 can be identified. (This device is known to be on qemu) + let devices = scan_devices(); + let mut found_rtl8139 = false; + for device in devices.iter() { + if device.class_code == PCIClassCodes::NetworkController && + device.vendor_id == 0x10EC && device.device_id == 0x8139 { + found_rtl8139 = true; + } + } + check!(found_rtl8139, "Searching for rtl8139"); + test_ok!(); +} \ No newline at end of file diff --git a/kernel/src/network/dhcp.rs b/kernel/src/network/dhcp.rs index 9741f83..663aa09 100644 --- a/kernel/src/network/dhcp.rs +++ b/kernel/src/network/dhcp.rs @@ -1,6 +1,6 @@ -use alloc::vec; +use alloc::{string::String, vec}; -use crate::crypto::rng::get_next_random_num; +use crate::{crypto::rng::get_next_random_num, serial_print, serial_println, check, network::{ethernet::EthType, layer::full_parse, constants::{DHCP_CLIENT_PORT, DHCP_SERVER_PORT}}, mark_as_test, test_ok}; use super::{ bytefield::{Bytefield128, Bytefield16, Bytefield32, Bytefield48, Bytefield8}, @@ -119,7 +119,11 @@ impl Layer for DHCPPacket { Self: Sized, { // Create a packet - let mut packet = DHCPPacket::new(); // create an empty packet + let mut packet = DHCPPacket::new(); + // Check bytevec size + if bytevec.len() < Self::packet_size() as usize { + return (packet, 0, LayerType::Err) + } let mut i = 0; // Fill in udp layer packet.udp = udp_layer; @@ -139,13 +143,14 @@ impl Layer for DHCPPacket { i += 64; // sname i += 128; // file i += 64; // options - // Ignoring those parts of the DHCP packet. - // Get data left to parse (might be variable length DHCP packet) + + // Ignoring those parts of the DHCP packet. + // Get data left to parse (might be variable length DHCP packet) let left_to_parse = packet.udp.length.val() - 308; // Increment I and assert its at least 300 i += left_to_parse as usize; - assert!(i >= 300); // 300 bytes - // Return the packet, the amount of data consumed, and the next layer type (end of parse) + assert!(i >= 300); + // Return the packet, the amount of data consumed, and the next layer type (end of parse) (packet, i, LayerType::End) } @@ -185,7 +190,7 @@ impl HasChecksum for DHCPPacket { // Starting vars let mut sum: u32 = 0; - // First we do the IP as a pseduo header + // First we do the IP as a pseudo header let ip = &self.udp.ip; sum += (ip.src_ip.data[0] as u32) | (ip.src_ip.data[1] as u32) << 8; sum += (ip.src_ip.data[2] as u32) | (ip.src_ip.data[3] as u32) << 8; @@ -211,3 +216,51 @@ impl HasChecksum for DHCPPacket { self.udp.ip.verify_checksum() } } + +pub fn test() -> Result<(), String> { + mark_as_test!("DHCP Packet"); + // Create an arp packet + let pkt: DHCPPacket = DHCPPacket::new(); + check!( + pkt.serialize().len() == DHCPPacket::packet_size() as usize + UDPPacket::packet_size() as usize + IPPacket::packet_size() as usize + EthernetPacket::packet_size() as usize, + "Check if dhcp + udp + ip + ethernet is correct size of serialized packet" + ); + + // Create a dhcp packet + use crate::network::ip::Protocol; + let eth = EthernetPacket::gen(1, 2, EthType::IPv4); + let ip_size = pkt.serialize().len() as u16 - EthernetPacket::packet_size(); + let ip = IPPacket::gen(eth, ip_size, Protocol::Udp, 3, 4); + let udp = UDPPacket::gen(ip, DHCP_SERVER_PORT, DHCP_CLIENT_PORT, DHCPPacket::packet_size()); + let dhcp = DHCPPacket::gen(udp, None, 0x010203040506); + check!(dhcp.client_hardware_address.swapped().val() == 0x010203040506_u128 << (10 * 8) , "Is mac address correct"); + check!(dhcp.my_ip.val() == 0, "Is IP correct"); + + // Serialize and deserialize + let serialized = dhcp.serialize(); + let (size, pkt) = full_parse(&serialized); + check!(pkt.get_type() == LayerType::Dhcp, "Check it deserialized to a dhcp packet"); + check!(size == serialized.len(), "Check it deserialized everything we serialized"); + let dhcp_pkt = pkt.unwrap_dhcp(); + check!(dhcp_pkt.op_code == dhcp.op_code, "Check field"); + check!(dhcp_pkt.hardware_type == dhcp.hardware_type, "Check field"); + check!(dhcp_pkt.hardware_address_length == dhcp.hardware_address_length, "Check field"); + check!(dhcp_pkt.hops == dhcp.hops, "Check field"); + check!(dhcp_pkt.transaction_identifier.swapped() == dhcp.transaction_identifier, "Check field"); + check!(dhcp_pkt.seconds.swapped() == dhcp.seconds, "Check field"); + check!(dhcp_pkt.flags.swapped() == dhcp.flags, "Check field"); + check!(dhcp_pkt.client_ip.swapped() == dhcp.client_ip, "Check field"); + check!(dhcp_pkt.my_ip.swapped() == dhcp.my_ip, "Check field"); + check!(dhcp_pkt.server_ip.swapped() == dhcp.server_ip, "Check field"); + check!(dhcp_pkt.gateway_ip.swapped() == dhcp.gateway_ip, "Check field"); + check!(dhcp_pkt.client_hardware_address.swapped() == dhcp.client_hardware_address, "Check field"); + + // Check full parse on an incomplete packet + let mut serialized = dhcp.serialize(); + // remove last element, making this packet invalid + serialized.pop(); + let (_, err_pkt) = full_parse(&serialized); + check!(err_pkt.get_type() == LayerType::Err, "Deserializing a dhcp packet without enough data is an error"); + + test_ok!(); +} diff --git a/kernel/src/network/errors.rs b/kernel/src/network/errors.rs index 7627a5a..adff997 100644 --- a/kernel/src/network/errors.rs +++ b/kernel/src/network/errors.rs @@ -11,7 +11,7 @@ pub enum NetworkErrors { BadSocketState, /// A placeholder Network Error for unimplemented features (for development) FeatureNotAvailableYet, - /// If a timeout occured + /// If a timeout occurred Timeout, /// This is a special network error for when our TCP stream has closed ClosedSocket, diff --git a/kernel/src/network/ethernet.rs b/kernel/src/network/ethernet.rs index f0de3f5..f01b4c5 100644 --- a/kernel/src/network/ethernet.rs +++ b/kernel/src/network/ethernet.rs @@ -1,8 +1,10 @@ +use crate::{serial_print, check, serial_println, network::layer::full_parse, mark_as_test, test_ok}; + use super::{ bytefield::{Bytefield16, Bytefield48}, layer::{EmptyLayer, Layer, LayerType}, }; -use alloc::vec; +use alloc::{vec, string::String}; use alloc::vec::Vec; /// Ethernet type for the packet @@ -33,7 +35,7 @@ impl EthType { } /// An ethernet packet, implements Layer (14 bytes) -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct EthernetPacket { pub dest_mac: Bytefield48, // u48 pub src_mac: Bytefield48, // u48, @@ -75,7 +77,10 @@ impl Layer for EthernetPacket { Self: Sized, { let mut packet = EthernetPacket::new(); // create an empty packet - // Read 14 bytes + // Read at least 14 bytes + if bytevec.len() < Self::packet_size() as usize { + return (packet, 0, LayerType::Err) + } let mut i = 0; packet.dest_mac = Bytefield48::read_inc(&bytevec[i..], &mut i); packet.src_mac = Bytefield48::read_inc(&bytevec[i..], &mut i); @@ -108,3 +113,47 @@ impl Layer for EthernetPacket { 14 } } + + + +pub fn test() -> Result<(), String> { + mark_as_test!("Ethernet Packet"); + // Create a ethernet packet + let pkt: EthernetPacket = EthernetPacket::new(); + check!(pkt.serialize().len() == EthernetPacket::packet_size() as usize, "Check if ethernet is correct size of serialized packet"); + + // Create another ethernet packet + let eth = EthernetPacket::gen(1, 2, EthType::IPv4); + check!(eth.dest_mac.swapped().val() == 1, "Check destination mac"); + check!(eth.src_mac.swapped().val() == 2, "Check src mac"); + check!(eth.packet_type == EthType::IPv4, "Check eth type"); + + // Serialize and deserialize + let serialized = eth.serialize(); + let (eth_pkt, count, next_type) = EthernetPacket::parse(EmptyLayer::new(), &serialized); + check!(next_type == LayerType::IP, "Check it's next layer is a IP packet"); + check!(count == serialized.len(), "Check it deserialized everything we serialized"); + check!(eth_pkt.src_mac.swapped() == eth.src_mac, "Check field is same"); + check!(eth_pkt.dest_mac.swapped() == eth.dest_mac, "Check field is same"); + check!(eth_pkt.packet_type == eth.packet_type, "Check field is same"); + + // Check next layer can be arp too + let eth2 = EthernetPacket::gen(1, 2, EthType::Arp); + let (eth_pkt2, count2, next_type2) = EthernetPacket::parse(EmptyLayer::new(), ð2.serialize()); + check!(next_type2 == LayerType::Arp, "Check it's next layer is a ARP packet"); + check!(count2 == serialized.len(), "Check it deserialized everything we serialized"); + check!(eth_pkt2.src_mac.swapped() == eth2.src_mac, "Check field is same"); + check!(eth_pkt2.dest_mac.swapped() == eth2.dest_mac, "Check field is same"); + check!(eth_pkt2.packet_type == eth2.packet_type, "Check field is same"); + + // Check full parse on an ethernet packet + let eth3 = EthernetPacket::gen(1, 2, EthType::IPv4); + let (_, eth_pkt3) = full_parse(ð3.serialize()); + check!(eth_pkt3.get_type() == LayerType::Err, "Serializing an ethernet packet without following IP packet is an error"); + + // Check parse on small vector to return an error as well + let random_vec = vec![0, 1, 40, 10, 40, 91, 2, 0, 9, 10, 11, 12]; + let (_, eth_pkt4) = full_parse(&random_vec); + check!(eth_pkt4.get_type() == LayerType::Err, "Deserializing an ethernet packet without enough data is an error"); + test_ok!(); +} diff --git a/kernel/src/network/init.rs b/kernel/src/network/init.rs index 6ad1bc3..36ada72 100644 --- a/kernel/src/network/init.rs +++ b/kernel/src/network/init.rs @@ -8,7 +8,7 @@ use crate::network::ethernet::{EthType, EthernetPacket}; use crate::network::ip::{IPPacket, Protocol}; use crate::network::layer::{Layer, LayerType}; use crate::network::raw_socket::RawSocket; -use crate::network::rtl8139::{disable_network_interrupts, enable_network_interrupts, NET_INFO}; +use crate::network::rtl8139::NET_INFO; use crate::network::udp::UDPPacket; use crate::task::executor::Executor; use crate::task::Task; @@ -16,16 +16,15 @@ use crate::{network::constants::DHCP_SERVER_PORT, println}; /// Initialize all the network related stuff pub fn init() { - // todo bundle the init phases + // todo: bundle the init phases } /// Initialize dhcp by finding my IP address pub async fn init_dhcp(wait_timeout: u8) -> bool { // open a socket (shouldn't fail since we are in the init process) - let mut socket = RawSocket::new(DHCP_CLIENT_PORT, 1).unwrap(); + let mut socket = RawSocket::new(DHCP_CLIENT_PORT as u64, 1).unwrap(); // Get the network driver object - disable_network_interrupts(); let rtl_dev_guard = NET_INFO.lock(); let rtl_dev_info = rtl_dev_guard.get_ref().unwrap(); @@ -33,7 +32,7 @@ pub async fn init_dhcp(wait_timeout: u8) -> bool { let eth = EthernetPacket::gen(BROADCAST_MAC, rtl_dev_info.mac_address.unwrap(), EthType::IPv4); let ip_size = DHCPPacket::packet_size() + UDPPacket::packet_size(); let ip = IPPacket::gen(eth, ip_size, Protocol::Udp, 0x0, BROADCAST_ADDR); - let udp = UDPPacket::gen(ip, DHCP_CLIENT_PORT as u16, DHCP_SERVER_PORT as u16, DHCPPacket::packet_size()); + let udp = UDPPacket::gen(ip, DHCP_CLIENT_PORT, DHCP_SERVER_PORT, DHCPPacket::packet_size()); let mut dhcp = DHCPPacket::gen(udp, None, rtl_dev_info.mac_address.unwrap()); dhcp.udp.ip.calculate_checksum(); dhcp.calculate_checksum(); @@ -43,7 +42,6 @@ pub async fn init_dhcp(wait_timeout: u8) -> bool { rtl_dev_info.send_packet(&packet_data); // Release the driver object drop(rtl_dev_guard); - enable_network_interrupts(); // Wait for response let mut retries = 0; @@ -60,12 +58,10 @@ pub async fn init_dhcp(wait_timeout: u8) -> bool { break; } else { // Resend the packet after a timeout or other error - disable_network_interrupts(); let rtl_dev_guard = NET_INFO.lock(); let rtl_dev_info = rtl_dev_guard.get_ref().unwrap(); - rtl_dev_info.send_packet(&packet_data); // send another packet - drop(rtl_dev_guard); - enable_network_interrupts(); + // send another packet + rtl_dev_info.send_packet(&packet_data); } // Don't try forever retries += 1; @@ -77,22 +73,17 @@ pub async fn init_dhcp(wait_timeout: u8) -> bool { // Close the raw socket socket.close(); - // Get the driver object again - disable_network_interrupts(); - let mut rtl_dev_guard = NET_INFO.lock(); - let rtl_dev_info = rtl_dev_guard.get_mut().unwrap(); + // Get the network stack object + let mut rtl_dev_config = NET_INFO.config.lock(); if pkt_data.get_type() == LayerType::Dhcp { // With the driver object, record the information from the DHCP server let dhcp_res = pkt_data.unwrap_dhcp(); - rtl_dev_info.my_ip_address = Some(dhcp_res.my_ip.val()); - rtl_dev_info.dhcp_server_ip = Some(dhcp_res.server_ip.val()); + rtl_dev_config.my_ip_address = Some(dhcp_res.my_ip.val()); + rtl_dev_config.dhcp_server_ip = Some(dhcp_res.server_ip.val()); // Debug print my IP - let ip = dhcp_res.my_ip.swapped_endianness(); + let ip = dhcp_res.my_ip.swapped(); println!("[INFO] IP-Address Assigned As {}.{}.{}.{}", ip[0], ip[1], ip[2], ip[3]); } - // Release the driver object - drop(rtl_dev_guard); - enable_network_interrupts(); true } diff --git a/kernel/src/network/ip.rs b/kernel/src/network/ip.rs index 76ed4ef..4952451 100644 --- a/kernel/src/network/ip.rs +++ b/kernel/src/network/ip.rs @@ -1,7 +1,10 @@ -use alloc::vec; use alloc::vec::Vec; +use alloc::{string::String, vec}; -use crate::crypto::rng::get_next_random_num; +use crate::network::ethernet::EthType; +use crate::network::layer::{full_parse, EmptyLayer}; +use crate::{check, mark_as_test, serial_println, test_ok}; +use crate::{crypto::rng::get_next_random_num, serial_print}; use super::{ bytefield::{Bytefield16, Bytefield32, Bytefield8}, @@ -11,7 +14,7 @@ use super::{ /// Protocol for IP /// - ICMP (in development), TCP, UDP -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[repr(u8)] pub enum Protocol { Icmp = 1, @@ -55,7 +58,7 @@ pub struct IPPacket { pub checksum: Bytefield16, /// The sender's IP address pub src_ip: Bytefield32, - /// The recepient's IP address + /// The recipient's IP address pub dest_ip: Bytefield32, } @@ -86,7 +89,7 @@ impl IPPacket { pub fn gen(eth_layer: EthernetPacket, data_length: u16, protocol: Protocol, src_ip: u32, dest_ip: u32) -> Self { // Generate a unique ID for the packet let identification = (get_next_random_num() % u16::MAX as u32) as u16; - + // Construct the packet IPPacket { eth: eth_layer, @@ -117,6 +120,9 @@ impl Layer for IPPacket { { // create an empty packet let mut packet = IPPacket::new(); + if bytevec.len() < Self::packet_size() as usize { + return (packet, 0, LayerType::Err); + } // Save ethernet packet and read 20 bytes let mut i = 0; packet.eth = eth_layer; @@ -190,18 +196,79 @@ impl HasChecksum for IPPacket { eth: EthernetPacket::new(), version_hlen: self.version_hlen, type_of_service: self.type_of_service, - total_length: self.total_length.swapped_endianness(), - identification: self.identification.swapped_endianness(), - flags_fragment_offset: self.flags_fragment_offset.swapped_endianness(), + total_length: self.total_length.swapped(), + identification: self.identification.swapped(), + flags_fragment_offset: self.flags_fragment_offset.swapped(), ttl: self.ttl, protocol: self.protocol, checksum: Bytefield16::new(0), - src_ip: self.src_ip.swapped_endianness(), - dest_ip: self.dest_ip.swapped_endianness(), + src_ip: self.src_ip.swapped(), + dest_ip: self.dest_ip.swapped(), }; ip.calculate_checksum(); // Return if the checksum is a match - ip.checksum.val() == self.checksum.swapped_endianness().val() + ip.checksum.val() == self.checksum.swapped().val() } } + +pub fn test() -> Result<(), String> { + mark_as_test!("IP Packet"); + + // Create an ip packet, check it serializes correctly + let pkt: IPPacket = IPPacket::new(); + check!( + pkt.serialize().len() == (EthernetPacket::packet_size() + IPPacket::packet_size()) as usize, + "Check serialization size" + ); + + // Create another IP packet packet + let eth = EthernetPacket::gen(1, 2, EthType::IPv4); + let ip_udp = IPPacket::gen(eth.clone(), 0, Protocol::Udp, 3, 4); + let ip_tcp = IPPacket::gen(eth, 0, Protocol::Tcp, 3, 4); + check!(ip_udp.src_ip.swapped().val() == 3, "Check src ip"); + check!(ip_udp.dest_ip.swapped().val() == 4, "Check dest ip"); + check!(ip_udp.protocol == Protocol::Udp, "Check protocol type"); + + // Serialize and deserialize + let serialized_udp = ip_udp.serialize(); + let serialized_tcp = ip_tcp.serialize(); + let (eth_layer_udp, i1, _) = EthernetPacket::parse(EmptyLayer::new(), &serialized_udp); + let (eth_layer_tcp, i2, _) = EthernetPacket::parse(EmptyLayer::new(), &serialized_tcp); + let (ip_pkt_udp, total_udp, should_be_udp) = IPPacket::parse(eth_layer_udp, &serialized_udp[i1..]); + let (ip_pkt_tcp, total_tcp, should_be_tcp) = IPPacket::parse(eth_layer_tcp, &serialized_tcp[i2..]); + check!(should_be_udp == LayerType::Udp, "Check it's next layer is a UDP packet"); + check!(should_be_tcp == LayerType::Tcp, "Check it's next layer is a TCP packet"); + check!( + total_udp + i1 == serialized_udp.len(), + "Check it deserialized everything we serialized" + ); + check!( + total_tcp + i2 == serialized_tcp.len(), + "Check it deserialized everything we serialized" + ); + check!(ip_pkt_udp.src_ip.swapped() == ip_udp.src_ip, "Check field is same"); + check!(ip_pkt_udp.dest_ip.swapped() == ip_udp.dest_ip, "Check field is same"); + check!(ip_pkt_udp.protocol == ip_udp.protocol, "Check field is same"); + check!(ip_pkt_tcp.src_ip.swapped() == ip_tcp.src_ip, "Check field is same"); + check!(ip_pkt_tcp.dest_ip.swapped() == ip_tcp.dest_ip, "Check field is same"); + check!(ip_pkt_tcp.protocol == ip_tcp.protocol, "Check field is same"); + + // Check full parse on an IP packet + let eth = EthernetPacket::gen(1, 2, EthType::IPv4); + let mut ip_vec = IPPacket::gen(eth, 0, Protocol::Udp, 3, 4).serialize(); + let (_, ip_res) = full_parse(&ip_vec); + check!( + ip_res.get_type() == LayerType::Err, + "Deserializing an ip packet without following Udp packet is an error" + ); + + // Check parse on smaller vector than expected should also return an error + ip_vec.pop(); + let (_, ip_res2) = full_parse(&ip_vec); + check!( + ip_res2.get_type() == LayerType::Err, + "Deserializing an ip packet without enough data is an error" + ); + test_ok!(); +} diff --git a/kernel/src/network/layer.rs b/kernel/src/network/layer.rs index 28cd197..72cec01 100644 --- a/kernel/src/network/layer.rs +++ b/kernel/src/network/layer.rs @@ -1,6 +1,12 @@ use alloc::boxed::Box; +use alloc::string::String; use alloc::vec::Vec; +use crate::network::constants::{DHCP_SERVER_PORT, DHCP_CLIENT_PORT}; +use crate::network::ethernet::EthType; +use crate::network::ip::Protocol; +use crate::{mark_as_test, test_ok, check, serial_println, serial_print}; + use super::arp::ArpPacket; use super::dhcp::DHCPPacket; use super::ethernet::EthernetPacket; @@ -63,6 +69,9 @@ pub trait HasChecksum { fn verify_checksum(&mut self) -> bool; } +/// Calculate the checksum on a body of data +/// Will finish off the checksum calculation (by swapping bytes and 1's complement), +/// therefore we provide prev_sum to account for any pre-summing steps pub fn calculate_checksum_inner(body: &[u8], prev_sum: u32) -> u16 { let mut body_len = body.len(); let mut sum = prev_sum; @@ -282,3 +291,44 @@ pub fn full_parse(packet: &[u8]) -> (usize, PacketData) { } } } + +pub fn test() -> Result<(), String> { + // Some general tests for full parse, but this should be mostly tested by each layer's test + mark_as_test!("Full Parse"); + let eth = EthernetPacket::gen(0, 0, EthType::IPv4); + let eth_for_arp = EthernetPacket::gen(0, 0, EthType::Arp); + let arp = ArpPacket::gen(eth_for_arp, 0, 0, false); + let ip_for_udp = IPPacket::gen(eth.clone(), UDPPacket::packet_size(), Protocol::Udp, 0, 0); + let ip_for_dhcp = IPPacket::gen(eth.clone(), UDPPacket::packet_size() + DHCPPacket::packet_size(), Protocol::Udp, 0, 0); + let ip_for_tcp = IPPacket::gen(eth.clone(), TCPPacket::packet_size(), Protocol::Tcp, 0, 0); + let udp = UDPPacket::gen(ip_for_udp, 0, 0, 0); + let udp_for_dhcp = UDPPacket::gen(ip_for_dhcp, DHCP_SERVER_PORT, DHCP_CLIENT_PORT, DHCPPacket::packet_size()); + let dhcp = DHCPPacket::gen(udp_for_dhcp, None, 0); + let tcp = TCPPacket::gen(ip_for_tcp, 0, 0); + + let packets = [ + (arp.serialize(), LayerType::Arp), + (udp.serialize(), LayerType::Udp), + (dhcp.serialize(), LayerType::Dhcp), + (tcp.serialize(), LayerType::Tcp), + ]; + + for pkt in packets.iter() { + let mut data = pkt.0.clone(); + let (_, undo_pkt1) = full_parse(&data); + check!(undo_pkt1.get_type() == pkt.1, "Checking type of deserialized packet"); + + // 1 above in amount + data.push(0); + let (_, err_pkt1) = full_parse(&data); + check!(err_pkt1.get_type() == pkt.1, "Checking type of deserialized packet"); + + // 1 under in amount + data.pop(); + data.pop(); + let (_, err_pkt2) = full_parse(&data); + check!(err_pkt2.get_type() == LayerType::Err, "Checking type of deserialized packet"); + } + + test_ok!(); +} \ No newline at end of file diff --git a/kernel/src/network/mod.rs b/kernel/src/network/mod.rs index 95d3e60..092a6c4 100644 --- a/kernel/src/network/mod.rs +++ b/kernel/src/network/mod.rs @@ -3,6 +3,7 @@ pub mod init; pub mod rtl8139; pub mod socket; pub mod errors; +pub mod test; // Internal workings mod constants; diff --git a/kernel/src/network/netsync.rs b/kernel/src/network/netsync.rs index 46818de..0a42231 100644 --- a/kernel/src/network/netsync.rs +++ b/kernel/src/network/netsync.rs @@ -1,18 +1,34 @@ +use alloc::string::String; use spin::MutexGuard; -use super::rtl8139::{NetworkConfig, RTL8139}; +use crate::{test_ok, mark_as_test, serial_println, serial_print, check, network::rtl8139::{are_network_interrupts_enabled, NET_INFO}}; -struct InterruptGuard {} +use super::rtl8139::{NetworkConfig, RTL8139, enable_network_interrupts, disable_network_interrupts}; + +struct NetworkInterruptGuard {} + +impl NetworkInterruptGuard { + pub fn new() -> Self { + disable_network_interrupts(); + NetworkInterruptGuard {} + } +} + +impl Drop for NetworkInterruptGuard { + fn drop(&mut self) { + // re-enable network interrupts when we drop + enable_network_interrupts(); + } +} /// A guard for disabling interrupts, accessing interrupt-sensitive locks, and then re-enabling interrupts -/// - in progress - doesn't function correctly yet -pub struct NetworkInterruptsGuard<'a> { +pub struct NetworkGuard<'a> { /// Internal mutex guard protected by the network interrupts guard data: MutexGuard<'a, Option>, + _interrupt_guard: NetworkInterruptGuard, } - -impl NetworkInterruptsGuard<'_> { +impl NetworkGuard<'_> { /// Get the internals mutable pub fn get_mut(&mut self) -> Option<&mut RTL8139> { return self.data.as_mut(); @@ -24,17 +40,11 @@ impl NetworkInterruptsGuard<'_> { } } -impl Drop for NetworkInterruptsGuard<'_> { - fn drop(&mut self) { - // re-enable network interrupts when we drop - // drop(self.data); - // enable_network_interrupts(); - // println!("Enabling network interrupts") - } -} /// A driver that is "safe" to access without deadlocking with interrupt handler pub struct SafeRTL8139 { + // NOTE: the field ordering here is critical because it decides the drop order. First the main + // mutex is released, then interrupts are re-enabled data: spin::Mutex>, pub config: spin::Mutex, } @@ -46,15 +56,17 @@ impl SafeRTL8139 { } /// Get the internals without deadlocking with interrupt handler by disabling interrupts with the network interrupts guard - pub fn lock(&self) -> NetworkInterruptsGuard { - // disable_network_interrupts(); - return NetworkInterruptsGuard { data: self.data.lock() }; + pub fn lock(&self) -> NetworkGuard { + NetworkGuard { + data: self.data.lock(), + _interrupt_guard: NetworkInterruptGuard::new() + } } /// Get the internals without disabling interrupts --> this is unsafe /// Should only be used in interrupt handler pub unsafe fn lock_no_disable(&self) -> MutexGuard> { - return self.data.lock(); + self.data.lock() } } @@ -80,3 +92,19 @@ impl InterruptCounter { self.data -= 1; } } + +pub fn test() -> Result<(), String> { + mark_as_test!("Synchronization"); + // Check lock with disabling + check!(are_network_interrupts_enabled(), "Network interrupts are enabled by default"); + let lock = NET_INFO.lock(); + check!(!are_network_interrupts_enabled(), "Locking disables interrupts"); + drop(lock); + check!(are_network_interrupts_enabled(), "Unlocking re-enables interrupts"); + + // check lock without disabling + let lock2 = unsafe { NET_INFO.lock_no_disable() }; + check!(are_network_interrupts_enabled(), "Lock no disable has no affect"); + drop(lock2); + test_ok!(); +} \ No newline at end of file diff --git a/kernel/src/network/network_query.rs b/kernel/src/network/network_query.rs index 76c5b75..57c424f 100644 --- a/kernel/src/network/network_query.rs +++ b/kernel/src/network/network_query.rs @@ -1,47 +1,121 @@ +use alloc::string::String; use futures_util::StreamExt; +use crate::{mark_as_test, check, test_ok, serial_print, serial_println}; + use super::{ arp::ArpPacket, constants::{ARP_PORT, BROADCAST_MAC}, ethernet::{EthType, EthernetPacket}, layer::{Layer, LayerType}, raw_socket::RawSocket, - rtl8139::{disable_network_interrupts, enable_network_interrupts, NET_INFO}, + rtl8139::NET_INFO, bytefield::{Bytefield32, Bytefield48}, }; /// A module for querying things with the network stack pub struct NetworkQuery {} impl NetworkQuery { + pub fn parse_ip(ip: &str) -> Option { + if ip == "localhost" { + return NetworkQuery::parse_ip("127.0.0.1"); + } + let mut data = Bytefield32::new(0); + let mut i = 0; + for tok in ip.split('.') { + let maybe_byte = tok.parse::(); + if maybe_byte.is_err() || i >= 4 { + return None; + } + data[3 - i] = maybe_byte.unwrap(); + i += 1; + } + if i != 4 { + return None; + } + Some(data.val()) + } + + fn parse_hex(chr: char) -> Option { + let data = match chr { + '0' => 0, + '1' => 1, + '2' => 2, + '3' => 3, + '4' => 4, + '5' => 5, + '6' => 6, + '7' => 7, + '8' => 8, + '9' => 9, + 'a' => 10, + 'b' => 11, + 'c' => 12, + 'd' => 13, + 'e' => 14, + 'f' => 15, + _ => 16 + }; + if data == 16 { + None + } else { + Some(data) + } + } + + pub fn parse_mac(ip: &str) -> Option { + let mut data = Bytefield48::new(0); + let mut i = 0; + for tok in ip.split(':') { + if tok.len() != 2 { + return None; + } + let byte_first_half = NetworkQuery::parse_hex(tok.chars().nth(0).unwrap()); + let byte_second_half = NetworkQuery::parse_hex(tok.chars().nth(1).unwrap()); + if byte_first_half.is_none() || byte_second_half.is_none() || i >= 6 { + return None; + } + data[5 - i] = (byte_first_half.unwrap() * 16) + byte_second_half.unwrap(); + i += 1; + } + if i != 6 { + return None; + } + Some(data.val()) + } + /// Get a mac address from an ip address /// wait_timeout is how many seconds until giving up - pub async fn get_mac_from_ip(wait_timeout: u32, ip: u32) -> Option { - // Acquire the driver - disable_network_interrupts(); - let mut rtl_dev_info_locked = NET_INFO.lock(); - let rtl_dev_info = rtl_dev_info_locked.get_mut().unwrap(); + pub async fn get_mac_from_ip(wait_timeout: u32, dst_ip: u32) -> Option { + let mut rtl_dev_config = NET_INFO.config.lock(); // iterate through entries in the arp table - for (index, entry) in rtl_dev_info.arp_table.iter().enumerate() { + for (index, entry) in rtl_dev_config.arp_table.iter().enumerate() { // if entry matches, we can return from the cache - if entry.ip == ip { + if entry.ip == dst_ip { // If entry is expired, remove and break if entry.try_expire() { - rtl_dev_info.arp_table.remove(index); + rtl_dev_config.arp_table.remove(index); break; } return Some(entry.mac); } } + // Extract my ip + let my_ip: u32 = rtl_dev_config.my_ip_address.unwrap(); + drop(rtl_dev_config); + + // Acquire the driver + let mut rtl_dev_info_locked = NET_INFO.lock(); + let rtl_dev_info = rtl_dev_info_locked.get_mut().unwrap(); // Otherwise, we send an arp packet and query the IP let eth_layer = EthernetPacket::gen(BROADCAST_MAC, rtl_dev_info.mac_address.unwrap(), EthType::Arp); - let arp_layer = ArpPacket::gen(eth_layer, rtl_dev_info.my_ip_address.unwrap(), ip, true); + let arp_layer = ArpPacket::gen(eth_layer, my_ip, dst_ip, true); rtl_dev_info.send_packet(&arp_layer.serialize()); // and release the driver drop(rtl_dev_info_locked); - enable_network_interrupts(); // wait for response by creating an ARP socket let mut socket = RawSocket::new(ARP_PORT, 3).unwrap(); @@ -59,23 +133,18 @@ impl NetworkQuery { continue; } let arp_pkt = pkt_data.unwrap_arp(); - if arp_pkt.src_ip.val() != ip { + if arp_pkt.src_ip.val() != dst_ip { continue; } // Once we've unwrapped the packet, we can close the socket and return the sender mac socket.close(); return Some(arp_pkt.src_mac.val()); } else { - // If we timed-out - disable_network_interrupts(); - // acquire the driver again + // If we timed-out, acquire the driver again let rtl_dev_guard = NET_INFO.lock(); let rtl_dev_info = rtl_dev_guard.get_ref().unwrap(); // send the packet rtl_dev_info.send_packet(&arp_layer.serialize()); - // release the driver - drop(rtl_dev_guard); - enable_network_interrupts(); } // Count retries and if we exceed the limit, we die retries += 1; @@ -86,3 +155,31 @@ impl NetworkQuery { } } } + + +pub async fn test() -> Result<(), String> { + mark_as_test!("NetworkQuery"); + let err_ip = NetworkQuery::parse_ip("127.0.01"); + check!(err_ip.is_none(), "Cannot properly parse 127.0.01"); + let err_ip2 = NetworkQuery::parse_ip("127.0.0.1.0"); + check!(err_ip2.is_none(), "Cannot properly parse 127.0.0.1.0"); + let err_ip3 = NetworkQuery::parse_ip("127.0.0.a"); + check!(err_ip3.is_none(), "Cannot properly parse 127.0.0.a"); + let err_ip4 = NetworkQuery::parse_ip("a.0.0.1"); + check!(err_ip4.is_none(), "Cannot properly parse a.0.0.1"); + + let localhost_ip = NetworkQuery::parse_ip("127.0.0.1"); + check!(localhost_ip.is_some(), "Can properly parse 127.0.0.1"); + let test_ip = NetworkQuery::parse_ip("1.2.3.4"); + check!(test_ip.unwrap() == 0x01020304, "Can properly parse 1.2.3.4"); + + let dest_ip = NetworkQuery::parse_ip("10.0.2.2").unwrap(); + let test_mac = NetworkQuery::get_mac_from_ip(3, dest_ip).await; + check!(test_mac.is_some(), "Checking found mac is what we expected"); + + let parse_mac = NetworkQuery::parse_mac("00:01:02:03:04:05"); + check!(parse_mac.unwrap() == 0x000102030405, "Parsing mac address"); + let pretend_mac = NetworkQuery::parse_mac("52:55:0a:00:02:02"); + check!(pretend_mac.is_some(), "Trying with letters now"); + test_ok!(); +} \ No newline at end of file diff --git a/kernel/src/network/processing.rs b/kernel/src/network/processing.rs index 5b08939..4e5e618 100644 --- a/kernel/src/network/processing.rs +++ b/kernel/src/network/processing.rs @@ -1,4 +1,4 @@ -use alloc::collections::VecDeque; +use alloc::{collections::VecDeque, string::String, vec}; use conquer_once::spin::OnceCell; use crossbeam_queue::ArrayQueue; use futures_util::{task::AtomicWaker, Stream, StreamExt}; @@ -12,12 +12,12 @@ use crate::{ ethernet::{EthType, EthernetPacket}, ip::{IPPacket, Protocol}, layer::{HasChecksum, Layer, LayerType, PacketData}, - raw_socket::wake_sockets, - rtl8139::{disable_network_interrupts, enable_network_interrupts, NET_INFO}, + raw_socket::{wake_sockets, RawSocket}, + rtl8139::NET_INFO, tcp::TCPPacket, - tcp_session::{SessionAction, TCPSession}, + tcp_session::{SessionAction, TCPSession}, udp::UDPPacket, }, - println, + println, mark_as_test, test_ok, check, serial_print, serial_println, }; use core::{ @@ -25,7 +25,10 @@ use core::{ task::{Context, Poll}, }; -use super::layer::full_parse; +use super::{ + layer::full_parse, + rtl8139::{disable_network_interrupts, enable_network_interrupts}, +}; /// A waker for waking the pending packet stream static PROCESS_VEC_WAKER: AtomicWaker = AtomicWaker::new(); @@ -35,6 +38,11 @@ pub struct PacketBuf { } /// An array queue for data to parse static PENDING_DATA: OnceCell> = OnceCell::uninit(); +static MAX_WINDOW_SIZE: u16 = 20; +static mut WINDOW_SIZE: u16 = MAX_WINDOW_SIZE; +pub fn get_window_size() -> u16 { + unsafe { WINDOW_SIZE } +} /// A empty struct with behavior /// - the idea is its a queue for the interrupt handler to enqueue vectors to represent packets @@ -48,7 +56,7 @@ impl PendingProcessingStream { fn new() -> Self { // Initialize the pending data array queue with max size 20 PENDING_DATA - .try_init_once(|| ArrayQueue::new(20)) + .try_init_once(|| ArrayQueue::new(MAX_WINDOW_SIZE as usize)) .expect("PendingProcessingStream::new should only be called once"); PendingProcessingStream { _private: () } } @@ -66,6 +74,7 @@ impl Stream for PendingProcessingStream { let queue = PENDING_DATA.try_get().expect("not initialized"); // Get the next entry and register to be woken up let data = queue.pop(); + unsafe { WINDOW_SIZE = MAX_WINDOW_SIZE - queue.len() as u16 }; // Release the driver enable_network_interrupts(); PROCESS_VEC_WAKER.register(cx.waker()); @@ -94,6 +103,30 @@ pub(crate) fn add_pkt_data(data: PacketBuf) { } } +/// Put a packet onto the local interface +/// Note: this isn't very fast since we are translating to a deconstructed format +/// (so not very useful beyond testing our network stack) +/// If you IPC, don't use this +pub(crate) fn local_send_pkt(data: &[u8]) { + // Try to get the array queue + if let Ok(queue) = PENDING_DATA.try_get() { + // Simple transfer into an array + let mut good: [u8; 1500] = [0; 1500]; + for (i, val) in data.iter().enumerate() { + good[i] = *val; + } + // And push data + if queue.push(PacketBuf { buffer: good, length: data.len() as u16 } ).is_err() { + println!("[WARN] packet queue full; dropping packet"); + } else { + // If we pushed data, wake the processing thing + PROCESS_VEC_WAKER.wake(); + } + } else { + println!("[WARN] packet queue uninitialized"); + } +} + /// Start the processing of packets /// - this function never terminates until we receive None from the raw packet stream (which won't ever happen) pub async fn init_packet_processing() { @@ -107,44 +140,43 @@ pub async fn init_packet_processing() { // Don't deal with unrecognized packets (will fail assert) continue; } - if amount_parsed_and_pkt.0 != pkt_data.length as usize && (pkt_data.length as usize) >= 64 { + if amount_parsed_and_pkt.0 != pkt_data.length as usize && (pkt_data.length as usize) > 64 { // Assert we had a proper amount of data processed without erroring out - println!("[ERR] amount parsed: {} --> true length: {}", amount_parsed_and_pkt.0, pkt_data.length); - assert!(amount_parsed_and_pkt.0 == pkt_data.length as usize || (pkt_data.length as usize) < 64); + println!( + "[ERR] amount parsed: {} --> true length: {}", + amount_parsed_and_pkt.0, pkt_data.length + ); + assert!(amount_parsed_and_pkt.0 == pkt_data.length as usize || (pkt_data.length as usize) <= 64); } - // Get the driver configuration - disable_network_interrupts(); - let mut rtl_dev_info_guard = NET_INFO.lock(); - // Get the device fields - let rtl_dev_info = rtl_dev_info_guard.get_mut().unwrap(); + // Get the network stack configuration let mut rtl_dev_config = NET_INFO.config.lock(); (|| { match amount_parsed_and_pkt.1 { PacketData::ARP(arp) => { - if arp.dest_ip.val() == rtl_dev_info.my_ip_address.unwrap_or(0) || arp.dest_ip.val() == BROADCAST_ADDR { - // WE GOT A REQUEST, so create a response - let eth_layer = EthernetPacket::gen(arp.src_mac.val(), rtl_dev_info.mac_address.unwrap(), EthType::Arp); - let ip_address = rtl_dev_info.my_ip_address.unwrap_or(BROADCAST_ADDR); - let arp_layer = ArpPacket::gen(eth_layer, ip_address, arp.src_ip.val(), false); - let arp_pkt = arp_layer.serialize(); - // and send it - rtl_dev_info.send_packet(&arp_pkt); - } else { + if arp.is_response() { // WE GOT A RESPONSE, saving into the arp table with an expiration of an hour - rtl_dev_info.arp_table.push(ArpEntry::new(arp.src_mac.val(), arp.src_ip.val(), 64800)); + rtl_dev_config.arp_table.push(ArpEntry::new(arp.src_mac.val(), arp.src_ip.val(), 64800)); // ? Note: there is an attack vector where I send infinite ARP responses and fill up our arp table // ? For now, I am content with not solving this issue... // If there was some process listening on the ARP "port" -> then we have to upstream the packet - if rtl_dev_info.open_ports.contains(&ARP_PORT) { + if rtl_dev_config.open_ports.contains(&ARP_PORT) { // if we are listening on the port, try to insert initialize it into the map - if !rtl_dev_info.ports.contains_key(&ARP_PORT) { - rtl_dev_info.ports.insert(ARP_PORT, VecDeque::new()); + if !rtl_dev_config.ports.contains_key(&ARP_PORT) { + rtl_dev_config.ports.insert(ARP_PORT, VecDeque::new()); } // Push the packet into the port structure and wake the port - rtl_dev_info.ports.get_mut(&ARP_PORT).unwrap().push_back(Ok(PacketData::ARP(arp))); + rtl_dev_config.ports.get_mut(&ARP_PORT).unwrap().push_back(Ok(PacketData::ARP(arp))); wake_sockets(ARP_PORT); } + } else if arp.dest_ip.val() == rtl_dev_config.my_ip_address.unwrap_or(0) || arp.dest_ip.val() == BROADCAST_ADDR { + // WE GOT A REQUEST, so create a response + let eth_layer = EthernetPacket::gen(arp.src_mac.val(), rtl_dev_config.my_mac_address.unwrap(), EthType::Arp); + let ip_address = rtl_dev_config.my_ip_address.unwrap_or(BROADCAST_ADDR); + let arp_layer = ArpPacket::gen(eth_layer, ip_address, arp.src_ip.val(), false); + let arp_pkt = arp_layer.serialize(); + // and send it + NET_INFO.lock().get_mut().unwrap().send_packet(&arp_pkt); } } PacketData::DHCP(mut dhcp) => { @@ -156,14 +188,14 @@ pub async fn init_packet_processing() { let dest_port = dhcp.udp.dest_port.val() as u64; println!("[HANDLER] Found DHCP packet"); // If we are listening on the DHCP port - if rtl_dev_info.open_ports.contains(&dest_port) { + if rtl_dev_config.open_ports.contains(&dest_port) { println!("[HANDLER] Port {} is open", dest_port); // Try to initialize the port data structure - if !rtl_dev_info.ports.contains_key(&dest_port) { - rtl_dev_info.ports.insert(dest_port, VecDeque::new()); + if !rtl_dev_config.ports.contains_key(&dest_port) { + rtl_dev_config.ports.insert(dest_port, VecDeque::new()); } // Push back the dhcp packet and wake the port - rtl_dev_info + rtl_dev_config .ports .get_mut(&dest_port) .unwrap() @@ -173,20 +205,23 @@ pub async fn init_packet_processing() { } PacketData::UDP(mut udp) => { if !udp.verify_checksum() { - println!("Cannot verify checksum"); // return to leave the match statement (its wrapped in a closure), dropping this packet return; } // UDP packet let dest_port = udp.dest_port.val() as u64; // If we are listening on the port - if rtl_dev_info.open_ports.contains(&dest_port) { + if rtl_dev_config.open_ports.contains(&dest_port) { // Try to initialize the port data structure - if !rtl_dev_info.ports.contains_key(&dest_port) { - rtl_dev_info.ports.insert(dest_port, VecDeque::new()); + if !rtl_dev_config.ports.contains_key(&dest_port) { + rtl_dev_config.ports.insert(dest_port, VecDeque::new()); } // Push back the UDP packet and wake the port - rtl_dev_info.ports.get_mut(&dest_port).unwrap().push_back(Ok(PacketData::UDP(udp))); + rtl_dev_config + .ports + .get_mut(&dest_port) + .unwrap() + .push_back(Ok(PacketData::UDP(udp))); wake_sockets(dest_port); } } @@ -197,13 +232,13 @@ pub async fn init_packet_processing() { } // TCP Packet let dest_port = tcp.dest_port.val() as u64; - if !rtl_dev_info.open_ports.contains(&dest_port) { + if !rtl_dev_config.open_ports.contains(&dest_port) { // If we aren't listening on the port, throw the packet out return; } // Try to initialize the port structure - if !rtl_dev_info.ports.contains_key(&dest_port) { - rtl_dev_info.ports.insert(dest_port, VecDeque::new()); + if !rtl_dev_config.ports.contains_key(&dest_port) { + rtl_dev_config.ports.insert(dest_port, VecDeque::new()); } // Create the session key let session_key = TCPSession::gen_session_key(tcp.ip.src_ip.val(), tcp.src_port.val(), tcp.dest_port.val()); @@ -220,9 +255,9 @@ pub async fn init_packet_processing() { let tcp_layer = TCPPacket::gen(ip_layer, tcp.dest_port.val(), tcp.src_port.val()); // Create a session with the template let session = TCPSession::new(tcp_layer, tcp.ip.src_ip.val(), tcp.src_port.val(), tcp.dest_port.val()); - // Lets upstream our recevied packet to the listening port (and wake it) + // Lets upstream our received packet to the listening port (and wake it) // - This is used to split the socket into an "Established" session socket and a "Listening" socket - rtl_dev_info + rtl_dev_config .ports .get_mut(&dest_port) .unwrap() @@ -239,14 +274,14 @@ pub async fn init_packet_processing() { if let Some(response) = ack_pkt.0 { // If we got a response packet to send back, send it // If no ack is received, the host will send another transmission for us to respond to - rtl_dev_info.send_packet(&response.serialize()); + NET_INFO.lock().get_mut().unwrap().send_packet(&response.serialize()); } // todo: Release TCP resources after 5 minutes if never closed? // Interpret the action from the process_recv function if ack_pkt.1 != SessionAction::Drop { // if we are listening on the session, try to init the packet queue on that end - if !rtl_dev_info.ports.contains_key(&session_key) { - rtl_dev_info.ports.insert(session_key, VecDeque::new()); + if !rtl_dev_config.ports.contains_key(&session_key) { + rtl_dev_config.ports.insert(session_key, VecDeque::new()); } let res = if ack_pkt.1 == SessionAction::PushUpstream { // If the action is upstream, then we push the packet to the raw socket to handle it's data (if present) @@ -259,15 +294,65 @@ pub async fn init_packet_processing() { unreachable!(); }; // Push back the packet or end-of-stream token to the session socket -> and then wake the socket - rtl_dev_info.ports.get_mut(&session_key).unwrap().push_back(res); + rtl_dev_config.ports.get_mut(&session_key).unwrap().push_back(res); wake_sockets(session_key); } } _ => {} // ignore other packets } })(); - // Release the guard and enable interrupts - drop(rtl_dev_info_guard); - enable_network_interrupts(); } } + + +pub async fn test() -> Result<(), String> { + mark_as_test!("Processing (pipeline stage 2)"); + + let eth = EthernetPacket::gen(1, 2, EthType::IPv4); + let ip = IPPacket::gen(eth, UDPPacket::packet_size() + 10, Protocol::Udp, 3, 4); + let mut udp = UDPPacket::gen(ip.clone(), 5, 4444, 10); + udp.data = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; + udp.ip.calculate_checksum(); + udp.calculate_checksum(); + let mut udp_without_checksum = UDPPacket::gen(ip, 6, 4444, 10); + udp_without_checksum.data = udp.data.clone(); + + let mut socket = RawSocket::new(4444, 10).unwrap(); + + let good = udp.serialize(); + let bad = udp_without_checksum.serialize(); + + // Create a queue of 1000 good packets, 20 bad packets + for _ in 0..50 { + for _ in 0..19 { + local_send_pkt(&good); + } + local_send_pkt(&bad); + serial_print!("-"); + // Process the packets until we run out + let mut i = 0; + while i != 19 { + serial_print!("."); + if let Some(pkt_or_err) = socket.next().await { + // Check timeouts + if let Err(err) = pkt_or_err { + check!(err == NetworkErrors::Timeout, "Timeout could happen once we run out of packets"); + break; + } + let pkt = pkt_or_err.unwrap(); + check!(pkt.get_type() == LayerType::Udp, "Checking packet field"); + let udp = pkt.unwrap_udp(); + check!(udp.src_port.val() == 5, "Checking src port to be 5, a way we distinguish good packets"); + check!(udp.data[6] == 6, "Checking data seems ok"); + i += 1; + } else { + // Timeout happened, break... + check!(false, "Timeout was not supposed to happen"); + break; + } + } + check!(i == 19, "Test RawSocket receiving correct amount from pipeline"); + } + socket.close(); + test_ok!(); +} \ No newline at end of file diff --git a/kernel/src/network/raw_array.rs b/kernel/src/network/raw_array.rs index dc2cddb..e502fbf 100644 --- a/kernel/src/network/raw_array.rs +++ b/kernel/src/network/raw_array.rs @@ -1,3 +1,7 @@ +use alloc::string::String; + +use crate::{mark_as_test, check, test_ok, serial_print, serial_println}; + /// An array to represent the buffer of the RTL8139 /// (will wrap it's array accesses as required by the data-sheet) pub struct WrappingRawArray { @@ -56,3 +60,46 @@ impl WrappingRawArray { } } + +pub fn test() -> Result<(), String> { + mark_as_test!("RawArray"); + let mut data: [u8; 10] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; + let mut raw_array = WrappingRawArray::new(data.as_ptr(), 10); + for i in 0..10 { + check!(raw_array.get_next_u8() == i, "Checking values on first loop around"); + } + for i in 0..10 { + check!(raw_array.get_next_u8() == i, "Checking values on second loop around"); + } + raw_array.shift_amount(10); + for i in 0..10 { + check!(raw_array.get_next_u8() == i, "Checking values on 4th loop around, after shifting"); + } + raw_array.shift_amount(5); + for i in 0..5 { + check!(raw_array.get_next_u8() == i + 5, "Checking values on 5th loop around, 5 ahead"); + } + check!(raw_array.get_next_u16() == 0x0100, "Next u16 (1)"); + check!(raw_array.get_next_u16() == 0x0302, "Next u16 (2)"); + check!(raw_array.get_next_u16() == 0x0504, "Next u16 (3)"); + check!(raw_array.get_next_u16() == 0x0706, "Next u16 (4)"); + check!(raw_array.get_next_u16() == 0x0908, "Next u16 (5)"); + let mut read_data: [u8; 1500] = [0; 1500]; + raw_array.trim(&mut read_data, 9); + for i in 0..9 { + check!(read_data[i] == data[i], "First 9 equal the data"); + data[i] = 0; + } + // ignore last number + raw_array.shift_amount(1); + for i in 0..10 { + let next_num = raw_array.get_next_u8(); + if i == 9 { + check!(next_num == 9, "Didn't change last number"); + } else { + check!(next_num == 0, "Data backing the array should have been modified"); + } + } + + test_ok!(); +} \ No newline at end of file diff --git a/kernel/src/network/raw_socket.rs b/kernel/src/network/raw_socket.rs index 5cca964..8ff1733 100644 --- a/kernel/src/network/raw_socket.rs +++ b/kernel/src/network/raw_socket.rs @@ -1,12 +1,19 @@ -use futures_util::{task::AtomicWaker, Stream}; +use alloc::{string::String, vec}; +use futures_util::{task::AtomicWaker, Stream, StreamExt}; use hashbrown::HashMap; -use crate::task::timeout::{cancel_timeout, register_timeout, TimeoutID}; - -use super::{ - layer::PacketData, - rtl8139::{disable_network_interrupts, enable_network_interrupts, NET_INFO}, errors::NetworkErrors, +use crate::{ + network::{ + ethernet::{EthType, EthernetPacket}, + ip::{IPPacket, Protocol}, + layer::{HasChecksum, Layer, LayerType}, + udp::UDPPacket, processing::local_send_pkt, + }, + task::timeout::{cancel_timeout, register_timeout, TimeoutID}, + mark_as_test, check, serial_print, serial_println, test_ok, }; + +use super::{errors::NetworkErrors, layer::PacketData, rtl8139::NET_INFO}; use core::task::{Context, Poll}; use lazy_static::lazy_static; @@ -32,25 +39,18 @@ impl RawSocket { /// - will register a timeout after timeout_in_epochs /// - will return itself or a network error if cannot bind to the port pub fn new(port: u64, timeout_in_epochs: u16) -> Result { - // Acquire the driver guard - disable_network_interrupts(); - let mut rtl_dev_info_guard = NET_INFO.lock(); - let rtl_dev_info = rtl_dev_info_guard.get_mut().unwrap(); + // Acquire the network stack object + let mut rtl_dev_config = NET_INFO.config.lock(); // Check if the port is in use - if rtl_dev_info.open_ports.contains(&port) { - enable_network_interrupts(); + if rtl_dev_config.open_ports.contains(&port) { return Err(NetworkErrors::PortInUse); } // If not then bind to it - rtl_dev_info.open_ports.insert(port); + rtl_dev_config.open_ports.insert(port); // and allocate a waker NEW_PACKET_WAKER.lock().insert(port, AtomicWaker::new()); - // Release the driver guard - drop(rtl_dev_info_guard); - enable_network_interrupts(); - // Return the raw socket's initial state Ok(RawSocket { port, @@ -62,10 +62,9 @@ impl RawSocket { /// Internal function to query for a packet fn try_get_packet_inner(&self) -> Option> { - // Acquire the driver and try pop from the queue - let mut rtl_dev_info_guard = NET_INFO.lock(); - let rtl_dev_info = rtl_dev_info_guard.get_mut().unwrap(); - match rtl_dev_info.ports.get_mut(&self.port) { + // Acquire the network stack object and try pop from the queue + let mut rtl_dev_config = NET_INFO.config.lock(); + match rtl_dev_config.ports.get_mut(&self.port) { Some(vec) => vec.pop_front(), None => None, } @@ -73,23 +72,19 @@ impl RawSocket { /// Close the raw socket and release the resources associated with it pub fn close(self) { - // The raw socket closes by acquiring the driver - disable_network_interrupts(); - let mut rtl_dev_info_guard = NET_INFO.lock(); - let rtl_dev_info = rtl_dev_info_guard.get_mut().unwrap(); + // The raw socket closes by acquiring the network stack + let mut rtl_dev_config = NET_INFO.config.lock(); // and closing the port so that we don't receive anymore packets - rtl_dev_info.open_ports.remove(&self.port); + rtl_dev_config.open_ports.remove(&self.port); // Remove the waker, since the port shouldn't have anymore listeners NEW_PACKET_WAKER.lock().remove(&self.port); // Try to clear all the pending packets from the port - if rtl_dev_info.ports.contains_key(&self.port) { - let vec = rtl_dev_info.ports.get_mut(&self.port); - vec.unwrap().clear(); + if let Some(vec) = rtl_dev_config.ports.get_mut(&self.port) { + vec.clear(); } // tcp session information is removed by the socket (not the raw socket) // since the socket is typed (tcp or udp) and the raw_socket purpose is just // for polling for packets - enable_network_interrupts(); } } @@ -100,9 +95,7 @@ impl Stream for RawSocket { /// Poll for a new packet on the raw socket (it's main purpose) fn poll_next(mut self: core::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { // Try to get a packet - disable_network_interrupts(); let pkt = self.try_get_packet_inner(); - enable_network_interrupts(); // Then register the port with a waker let locked_waker_map = NEW_PACKET_WAKER.lock(); @@ -131,8 +124,48 @@ impl Stream for RawSocket { /// Wake sockets by port pub(crate) fn wake_sockets(port: u64) { let guard = NEW_PACKET_WAKER.lock(); - if guard.contains_key(&port){ + if guard.contains_key(&port) { // wake the port up, if possible guard[&port].wake(); } } + +pub async fn test() -> Result<(), String> { + mark_as_test!("RawSocket (pipeline stage 3)"); + let rtl_config_guard = NET_INFO.config.lock(); + let my_ip = rtl_config_guard.my_ip_address.unwrap_or(0); + let my_mac = rtl_config_guard.my_mac_address.unwrap_or(0); + drop(rtl_config_guard); + + let eth = EthernetPacket::gen(my_mac, 0, EthType::IPv4); + let ip = IPPacket::gen(eth, UDPPacket::packet_size() + 10, Protocol::Udp, 1, my_ip); + let mut udp = UDPPacket::gen(ip.clone(), 51071, 5554, 10); + udp.data = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; + udp.ip.calculate_checksum(); + udp.calculate_checksum(); + let data = udp.serialize(); + + let mut socket = RawSocket::new(5554, 10).unwrap(); + let next = socket.next().await; + check!(next.is_none(), "In raw-socket land, none is from a timeout"); + socket.close(); + + for _k in 0..100 { + let mut socket = RawSocket::new(5554, 10).unwrap(); + local_send_pkt(&data); + loop { + if let Some(Ok(packet)) = socket.next().await { + check!(packet.get_type() == LayerType::Udp, "Checking packet is udp only"); + let udp = packet.unwrap_udp(); + check!(udp.data.len() == 10, "Checking data is correct size"); + check!(udp.data[6] == 6, "Checking data is correct content"); + break; + } else { + // timeout, try adding the packet again + local_send_pkt(&data); + } + } + socket.close(); + } + test_ok!(); +} diff --git a/kernel/src/network/rtl8139.rs b/kernel/src/network/rtl8139.rs index 0ff5898..20d2cea 100644 --- a/kernel/src/network/rtl8139.rs +++ b/kernel/src/network/rtl8139.rs @@ -17,10 +17,11 @@ use x86_64::{ use crate::interrupts::{IDT, PIC_1_OFFSET}; use crate::network::constants::{ - CAPR, CMD_REG, CMD_REG_BUFE, CMD_REG_RE, CMD_REG_RST, CMD_REG_TE, RX_BROADCAST, RX_BUFFER_SIZE, RX_BUF_REG, RX_READ_PTR_MASK, - RX_START_REG, RX_MULTICAST, RX_PHYSICAL_MATCH, RX_PROMISCUOUS, CONFIG_1_REG, + CAPR, CMD_REG, CMD_REG_BUFE, CMD_REG_RE, CMD_REG_RST, CMD_REG_TE, CONFIG_1_REG, RX_BROADCAST, RX_BUFFER_SIZE, RX_BUF_REG, RX_MULTICAST, + RX_PHYSICAL_MATCH, RX_PROMISCUOUS, RX_READ_PTR_MASK, RX_START_REG, }; use crate::network::raw_array::WrappingRawArray; +use crate::serial_println; use super::constants::{IMR_REG, INTERRUPT_MASK, ISR_REG, ROK, RTL_DEV, RTL_VEND, TOK, TRANSMIT_CMD, TRANSMIT_REG}; use super::errors::NetworkErrors; @@ -86,6 +87,11 @@ pub fn enable_network_interrupts() { // Note: If the counter is not 0, someone else has disabled interrupts so we cannot enable it yet } +// If network interrupts are enabled +pub fn are_network_interrupts_enabled() -> bool { + unsafe { INTERRUPTS_ARE_ENABLED.lock().get() == 0 } +} + pub extern "x86-interrupt" fn network_handle(_stack_frame: InterruptStackFrame) { // Try to get the device info let mut net_dev = unsafe { NET_INFO.lock_no_disable() }; @@ -94,8 +100,8 @@ pub extern "x86-interrupt" fn network_handle(_stack_frame: InterruptStackFrame) } // Get the device fields let rtl_dev_info = net_dev.as_mut().unwrap(); - let io_base = rtl_dev_info.config.io_base; - let irq = rtl_dev_info.config.irq; + let io_base = rtl_dev_info.device_info.io_base; + let irq = rtl_dev_info.device_info.irq; if io_base.is_none() || irq.is_none() { println!("[ERR] Handling packet - missing data"); unsafe { @@ -135,7 +141,7 @@ fn recv_packet(rtl_dev_info: &RTL8139) { panic!("RTL8139 is not initialized properly"); } // Make sure buffer isn't empty - let cmd_reg = (rtl_dev_info.config.io_base.unwrap() + CMD_REG) as u16; + let cmd_reg = (rtl_dev_info.device_info.io_base.unwrap() + CMD_REG) as u16; let mut cmd_port = Port::::new(cmd_reg); while unsafe { cmd_port.read() } & CMD_REG_BUFE == 0x0 { // Receive a packet by reading the buffer @@ -148,7 +154,8 @@ fn recv_packet(rtl_dev_info: &RTL8139) { // Checking receive OK and no errors if header & 0x01 != 0 && header & 0x02 == 0 && header & 0x04 == 0 && header & 0x20 == 0 { let length = rx_buffer.get_next_u16(); // get the next two bytes - if length != 0 { + if length != 0 && (length - 4) <= 1500 { + // If we have a proper let mut packet: [u8; 1500] = [0; 1500]; rx_buffer.trim(&mut packet, (length - 4) as usize); // ? throw out the crc... we don't need to check it... @@ -171,7 +178,7 @@ fn recv_packet(rtl_dev_info: &RTL8139) { unsafe { RECV_POS = ((RECV_POS + 4) & RX_READ_PTR_MASK) % RX_BUFFER_SIZE; } - let mut capr = Port::::new((rtl_dev_info.config.io_base.unwrap() + CAPR) as u16); + let mut capr = Port::::new((rtl_dev_info.device_info.io_base.unwrap() + CAPR) as u16); unsafe { capr.write(RECV_POS - 0x10) }; } else { unsafe { @@ -183,22 +190,22 @@ fn recv_packet(rtl_dev_info: &RTL8139) { } // TODO: Split the driver into separate bits so we can lock individual resources? -// ! Otherwise we will have a bottleneck? pub struct RTL8139 { - pub config: Device, + pub device_info: Device, recv_buffer: Option, // 12KB send_buffer: Option, // 12KB physical_mem_offset: Option, + pub mac_address: Option, +} + +pub struct NetworkConfig { pub my_ip_address: Option, + pub my_mac_address: Option, pub dhcp_server_ip: Option, - pub mac_address: Option, pub open_ports: HashSet, pub ports: HashMap>>, pub arp_table: Vec, pub to_expire: VecDeque, -} - -pub struct NetworkConfig { pub tcp_sessions: HashMap, } @@ -206,6 +213,13 @@ impl NetworkConfig { pub fn new() -> Self { NetworkConfig { tcp_sessions: HashMap::with_capacity(10), + my_ip_address: None, + dhcp_server_ip: None, + my_mac_address: None, + open_ports: HashSet::with_capacity(10), + ports: HashMap::with_capacity(10), + arp_table: Vec::new(), + to_expire: VecDeque::with_capacity(20), } } } @@ -264,17 +278,11 @@ impl RTL8139 { unsafe { IO_BASE = device.io_base.unwrap() as usize }; // Return the device and a 12KB physical region return Some(RTL8139 { - config: device, + device_info: device, recv_buffer: None, send_buffer: None, - my_ip_address: None, - dhcp_server_ip: None, - physical_mem_offset: None, mac_address: None, - open_ports: HashSet::with_capacity(10), - ports: HashMap::with_capacity(10), - arp_table: Vec::new(), - to_expire: VecDeque::with_capacity(20), + physical_mem_offset: None, }); } None @@ -285,15 +293,15 @@ impl RTL8139 { /// @return True on success fn setup(&mut self) -> bool { // disable the interrupt line to prevent early triggering of an interrupt before we register the handler - InterruptHandler::block_irq(self.config.irq.unwrap()); + InterruptHandler::block_irq(self.device_info.irq.unwrap()); // enable bus mastering - let mut cr = self.config.read_command_register(); + let mut cr = self.device_info.read_command_register(); cr.set_bus_master_bit(true); - self.config.write_command_register(cr); + self.device_info.write_command_register(cr); // Check bus mastering - let cr = self.config.read_command_register(); + let cr = self.device_info.read_command_register(); match cr.get_bus_master_bit() { true => println!("[INFO] Bus mastering enabled!"), false => { @@ -306,20 +314,20 @@ impl RTL8139 { false => println!("[INFO] Interrupts enabled"), } // Turning on the RTL8139 - if self.config.io_base.is_none() { + if self.device_info.io_base.is_none() { println!("[ERR] Cannot find IO-Address"); return false; } - if self.config.irq.is_none() { + if self.device_info.irq.is_none() { println!("[ERR] Cannot find IRQ"); return false; } // Register the interrupt handler for the card - IDT.lock().register_irq(self.config.irq.unwrap() as usize, network_handle); + IDT.lock().register_irq(self.device_info.irq.unwrap() as usize, network_handle); // Get MAC address - let mac_addr = self.config.io_base.unwrap(); // + 0 offset + let mac_addr = self.device_info.io_base.unwrap(); // + 0 offset let mut mac1 = Port::::new(mac_addr as u16); let mut mac2 = Port::::new((mac_addr + 0x04) as u16); let mut mac_addr = 0x0_u64; @@ -328,16 +336,17 @@ impl RTL8139 { mac_addr |= (mac2.read() as u64) << 32; }; // save mac address + NET_INFO.config.lock().my_mac_address = Some(mac_addr); self.mac_address = Some(mac_addr); println!("[INFO] MAC address is {:#10x}", mac_addr); // turn on the card by setting config 1 to 0x00 - let config_1_reg = self.config.io_base.unwrap() as u16 + CONFIG_1_REG; + let config_1_reg = self.device_info.io_base.unwrap() as u16 + CONFIG_1_REG; let mut port_config_1 = Port::::new(config_1_reg); unsafe { port_config_1.write(0x00) }; // Performing software reset - let cmd_reg = self.config.io_base.unwrap() + CMD_REG; + let cmd_reg = self.device_info.io_base.unwrap() + CMD_REG; let mut port_rst = Port::::new(cmd_reg as u16); unsafe { port_rst.write(CMD_REG_RST); @@ -350,7 +359,7 @@ impl RTL8139 { unsafe { port_recv_transmit.write(CMD_REG_TE | CMD_REG_RE) }; // Configuring receive buffer - let rcr_reg = self.config.io_base.unwrap() as u16 + RX_BUF_REG; + let rcr_reg = self.device_info.io_base.unwrap() as u16 + RX_BUF_REG; let mut rcr = Port::::new(rcr_reg); unsafe { // No wrap, ring buffer, accept all packets @@ -358,21 +367,23 @@ impl RTL8139 { }; // Init receive buffer - let rcv_buf_reg = self.config.io_base.unwrap() as u16 + RX_START_REG; + let rcv_buf_reg = self.device_info.io_base.unwrap() as u16 + RX_START_REG; let mut rcv_buffer = Port::::new(rcv_buf_reg); unsafe { rcv_buffer.write(self.recv_buffer.unwrap().as_u64() as u32) }; // Set IMR + ISR - let imr_reg = self.config.io_base.unwrap() as u16 + IMR_REG; + let imr_reg = self.device_info.io_base.unwrap() as u16 + IMR_REG; let mut imr = Port::::new(imr_reg); unsafe { imr.write(INTERRUPT_MASK) }; // Enable interrupts // How I understood why: https://forum.osdev.org/viewtopic.php?f=1&t=27901 - InterruptHandler::unblock_irq(self.config.irq.unwrap()); + InterruptHandler::unblock_irq(self.device_info.irq.unwrap()); true } + // todo: change to borrowed slice && make all enums derive debug + // ! ^^ /// Write the packet data to the card and notify /// This will send the packet pub fn send_packet(&self, packet_data: &Vec) { @@ -381,7 +392,7 @@ impl RTL8139 { panic!("RTL8139 is not initialized properly"); } // If we don't have a mac address, stop - let io_base = self.config.io_base; + let io_base = self.device_info.io_base; if self.mac_address.is_none() || io_base.is_none() { return; } @@ -417,3 +428,6 @@ impl RTL8139 { }; } } + +// Note: to actually test the network driver, we need to put packets on the network +// For this, I wrote a python script instead of an in-OS test. diff --git a/kernel/src/network/socket.rs b/kernel/src/network/socket.rs index 7e9fe7e..2dbbf71 100644 --- a/kernel/src/network/socket.rs +++ b/kernel/src/network/socket.rs @@ -1,9 +1,11 @@ use core::cmp::{max, min}; -use alloc::{vec, boxed::Box}; +use alloc::string::String; use alloc::vec::Vec; +use alloc::{boxed::Box, vec}; use futures_util::StreamExt; +use crate::{test_ok, mark_as_test, check, serial_print, serial_println}; use crate::{network::layer::LayerType, println}; use super::{ @@ -14,7 +16,7 @@ use super::{ layer::{HasChecksum, Layer, PacketData}, network_query::NetworkQuery, raw_socket::RawSocket, - rtl8139::{disable_network_interrupts, enable_network_interrupts, NET_INFO}, + rtl8139::NET_INFO, tcp_session::TCPSession, udp::UDPPacket, }; @@ -60,17 +62,15 @@ impl Socket { // Start the chosen src port as src_port let mut chosen_src_port = src_port; - // Acquire the network driver - disable_network_interrupts(); - let rtl_dev_info_locked = NET_INFO.lock(); - let rtl_dev_info = rtl_dev_info_locked.get_ref().unwrap(); + // Acquire the network stack object + let rtl_dev_config = NET_INFO.config.lock(); // get the src mac and address from the driver - let src_mac = rtl_dev_info.mac_address.unwrap(); - let src_address = rtl_dev_info.my_ip_address.unwrap(); + let src_mac = rtl_dev_config.my_mac_address.unwrap(); + let src_address = rtl_dev_config.my_ip_address.unwrap(); // If our source port is 0 if src_port == 0 { - let open_ports = &rtl_dev_info.open_ports; + let open_ports = &rtl_dev_config.open_ports; // Linear probe to see if any ports are open for i in 1000..u16::MAX { if !open_ports.contains(&(i as u64)) { @@ -80,9 +80,8 @@ impl Socket { } } } - // Release the driver code - drop(rtl_dev_info_locked); - enable_network_interrupts(); + // Release the driver once we are done + drop(rtl_dev_config); // If our chosen src port is 0, we have no available port error if chosen_src_port == 0 { @@ -144,20 +143,15 @@ impl Socket { // and transition into ready state self.socket_state = SocketState::Ready; - // acquire the driver - disable_network_interrupts(); - let mut rtl_dev_info_guard = NET_INFO.lock(); - let rtl_dev_info = rtl_dev_info_guard.get_mut().unwrap(); + // acquire the network stack object + let mut rtl_dev_config = NET_INFO.config.lock(); // re-enqueue the packet - if let Some(vec) = rtl_dev_info.ports.get_mut(&(self.src_port as u64)) { + if let Some(vec) = rtl_dev_config.ports.get_mut(&(self.src_port as u64)) { vec.push_front(Ok(PacketData::UDP(Box::new(udp_pkt)))); } - // release the driver - drop(rtl_dev_info_guard); - enable_network_interrupts(); - // and return itself + // and return "itself" return Ok(None); } else if pkt.get_type() == LayerType::Tcp && self.socket_type == SocketType::TCP { // If we have a match (TCP <--> TCP) @@ -203,18 +197,16 @@ impl Socket { // Get the chosen port let mut chosen_src_port = src_port; - // Acquire the driver - disable_network_interrupts(); - let rtl_dev_info_locked = NET_INFO.lock(); - let rtl_dev_info = rtl_dev_info_locked.get_ref().unwrap(); + // Acquire the network stack object + let rtl_dev_config = NET_INFO.config.lock(); // Extract src mac and src address - let src_mac = rtl_dev_info.mac_address.unwrap(); - let src_address = rtl_dev_info.my_ip_address.unwrap(); + let src_mac = rtl_dev_config.my_mac_address.unwrap(); + let src_address = rtl_dev_config.my_ip_address.unwrap(); // if src port is 0 if src_port == 0 { - let open_ports = &rtl_dev_info.open_ports; + let open_ports = &rtl_dev_config.open_ports; // Linear probe to see if any ports are open for i in 1000..u16::MAX { if !open_ports.contains(&(i as u64)) { @@ -226,8 +218,7 @@ impl Socket { } // Release the driver - drop(rtl_dev_info_locked); - enable_network_interrupts(); + drop(rtl_dev_config); // Also query for a destination mac address let dest_mac = NetworkQuery::get_mac_from_ip(10, dest_address).await; @@ -289,12 +280,10 @@ impl Socket { } // Send the FIN-ACK packet - disable_network_interrupts(); let mut rtl_dev_info_locked = NET_INFO.lock(); let rtl_dev_info = rtl_dev_info_locked.get_mut().unwrap(); rtl_dev_info.send_packet(&pkt.serialize()); drop(rtl_dev_info_locked); - enable_network_interrupts(); loop { // Keep reading the stream until we get an error @@ -333,22 +322,28 @@ impl Socket { /// Internal function for reading as a UDP socket #[allow(clippy::question_mark)] async fn read_udp(&mut self, size: usize) -> Result, NetworkErrors> { - // todo: fix reading loop { - // Spin until we get some packet + // Create a result vector + let mut res_vec = vec![]; + // Loop until we get a packet if let Some(pkt_or_err) = self.raw_socket.next().await { - // If we poll an error, we pass it to the read as an error + // If the packet is an error, read_udp returns an error if let Err(err) = pkt_or_err { return Err(err); } let pkt = pkt_or_err.unwrap(); if pkt.get_type() == LayerType::Udp { - // If we have a matching UDP packet, we pass the data to the read result - let udp_pkt = pkt.unwrap_udp(); - return Ok(udp_pkt.data); + // If we have a matching UDP packet, unwrap it + let mut udp_pkt = pkt.unwrap_udp(); + // Save the packet into our buffer + res_vec.append(&mut udp_pkt.data); + } + // Our size has exceeded the request, so we can return the resulting data + if size <= res_vec.len() { + return Ok(res_vec); } } else { - // We got a timeout and we return for UDP when this happens + // Our socket timed-out reading, so we return the timeout error return Err(NetworkErrors::Timeout); } } @@ -422,11 +417,14 @@ impl Socket { let eth_layer = EthernetPacket::gen(self.src_mac, self.dest_mac, ethernet::EthType::IPv4); let udp_size = UDPPacket::packet_size() + data.len() as u16; let ip_layer = IPPacket::gen(eth_layer, udp_size, Protocol::Udp, self.src_address, self.dest_ip); - let data_len = min(data.len(), 60000); + let data_len = min(data.len(), 1380); let mut udp_layer = UDPPacket::gen(ip_layer, self.src_port, self.dest_port, data_len as u16); udp_layer.data = data[..data_len].to_vec(); let mut index = 0; - data.retain(|_| {index += 1; index > data_len}); + data.retain(|_| { + index += 1; + index > data_len + }); udp_layer.ip.calculate_checksum(); udp_layer.calculate_checksum(); @@ -434,9 +432,7 @@ impl Socket { let packet_data = udp_layer.serialize(); // Send the packet - disable_network_interrupts(); NET_INFO.lock().get_ref().unwrap().send_packet(&packet_data); - enable_network_interrupts(); // Return how much was written Ok(data_len as u16) @@ -454,12 +450,15 @@ impl Socket { // Set up to receive ack by processing the send on the data let mut tcp_session_guard = NET_INFO.config.lock(); let tcp_session = tcp_session_guard.tcp_sessions.get_mut(&self.session_key).unwrap(); - let data_len = min(data.len(), tcp_session.window_size as usize); + let data_len = min(data.len(), min(tcp_session.window_size as usize, 1380)); // Message pkt is our present to send to our server let message_pkt = tcp_session.process_send(&data[..data_len]); // Trim the array to remove the sent data let mut index = 0; - data.retain(|_| {index += 1; index > data_len}); + data.retain(|_| { + index += 1; + index > data_len + }); // Do error checking if let Err(err) = message_pkt { return Err(err); @@ -471,14 +470,12 @@ impl Socket { // Wait for the ack -- 20 retries for retries in 1..21 { // Acquire the driver - disable_network_interrupts(); let mut rtl_dev_info_locked = NET_INFO.lock(); let rtl_dev_info = rtl_dev_info_locked.get_mut().unwrap(); // Send the packet rtl_dev_info.send_packet(&message); // Release the driver drop(rtl_dev_info_locked); - enable_network_interrupts(); // Get next packet or timeout if let Some(pkt_or_err) = self.raw_socket.next().await { @@ -507,3 +504,12 @@ impl Socket { Ok(data_len as u16) } } + + + +pub async fn test() -> Result<(), String> { + mark_as_test!("Socket"); + check!(true, "Yay!"); + serial_println!("todo: this test file, using local_send"); + test_ok!(); +} \ No newline at end of file diff --git a/kernel/src/network/tcp.rs b/kernel/src/network/tcp.rs index 0052afd..5562cf2 100644 --- a/kernel/src/network/tcp.rs +++ b/kernel/src/network/tcp.rs @@ -98,7 +98,7 @@ impl TCPPacket { /// N.B. Setting operations swap endianness because when we use val, we are in network-order /// Turn on the flags provided pub fn turn_on_flags(&mut self, flags: u8) { - let new_flags = self.flags.swapped_endianness().val() | (flags as u16); + let new_flags = self.flags.swapped().val() | (flags as u16); self.flags = Bytefield16::new(new_flags); } @@ -122,7 +122,10 @@ impl Layer for TCPPacket { { // Create an empty packet let mut packet = TCPPacket::new(); - + // Check bytevec size + if bytevec.len() < Self::packet_size() as usize { + return (packet, 0, LayerType::Err) + } // Save ip packet and read 14 bytes let mut i = 0; packet.ip = ip_layer; @@ -134,6 +137,10 @@ impl Layer for TCPPacket { packet.sliding_window = Bytefield16::read_inc(&bytevec[i..], &mut i); packet.checksum = Bytefield16::read_inc(&bytevec[i..], &mut i); packet.urgent = Bytefield16::read_inc(&bytevec[i..], &mut i); + // Check bytevec size + if (bytevec.len() as u8) < packet.get_header_offset() { + return (packet, 0, LayerType::Err) + } // read remaining bytes of header and place them into the options buffer for _ in 0..(packet.get_header_offset() - 20) { packet.options.push(bytevec[i]); @@ -143,6 +150,10 @@ impl Layer for TCPPacket { assert!(i == packet.get_header_offset().into()); // Get the data size let data_size = packet.ip.total_length.val() - packet.get_header_offset() as u16 - IPPacket::packet_size(); + // Check bytevec size + if (bytevec.len() as u16) < packet.get_header_offset() as u16 + data_size { + return (packet, 0, LayerType::Err) + } // Read everything else and save it into data buffer for _ in 0..data_size { packet.data.push(bytevec[i]); @@ -216,31 +227,31 @@ impl HasChecksum for TCPPacket { eth: EthernetPacket::new(), version_hlen: self.ip.version_hlen, type_of_service: self.ip.type_of_service, - total_length: self.ip.total_length.swapped_endianness(), - identification: self.ip.identification.swapped_endianness(), - flags_fragment_offset: self.ip.flags_fragment_offset.swapped_endianness(), + total_length: self.ip.total_length.swapped(), + identification: self.ip.identification.swapped(), + flags_fragment_offset: self.ip.flags_fragment_offset.swapped(), ttl: self.ip.ttl, protocol: self.ip.protocol, checksum: Bytefield16::new(0), - src_ip: self.ip.src_ip.swapped_endianness(), - dest_ip: self.ip.dest_ip.swapped_endianness(), + src_ip: self.ip.src_ip.swapped(), + dest_ip: self.ip.dest_ip.swapped(), }; ip.calculate_checksum(); - if self.ip.checksum.swapped_endianness().val() != ip.checksum.val() { + if self.ip.checksum.swapped().val() != ip.checksum.val() { return false; } // Clone the packet in host order let mut tcp: TCPPacket = TCPPacket { ip, - src_port: self.src_port.swapped_endianness(), - dest_port: self.dest_port.swapped_endianness(), - seq_num: self.seq_num.swapped_endianness(), - ack_num: self.ack_num.swapped_endianness(), - flags: self.flags.swapped_endianness(), - sliding_window: self.sliding_window.swapped_endianness(), + src_port: self.src_port.swapped(), + dest_port: self.dest_port.swapped(), + seq_num: self.seq_num.swapped(), + ack_num: self.ack_num.swapped(), + flags: self.flags.swapped(), + sliding_window: self.sliding_window.swapped(), checksum: Bytefield16::new(0), - urgent: self.urgent.swapped_endianness(), + urgent: self.urgent.swapped(), options: self.options.clone(), data: self.data.clone(), }; @@ -258,12 +269,12 @@ impl HasChecksum for TCPPacket { tcp.calculate_checksum(); // println!("CHECKSUMS: {:#10x} {:#10x}", tcp.checksum.val(), self.checksum.swapped_endianness().val()); // println!("CHECKSUMS: {} {}", tcp.checksum.val(), self.checksum.swapped_endianness().val()); - // todo: figure out why I am off by 1024 when data is empty. This is a terrible way to fix a bug (and doesn't always works) + // todo: figure out why I am off by 1024 when data is empty. This is a terrible way to fix a bug by ignoring this edge case if tcp.data.is_empty() { true // tcp.checksum.val() - 1024 == self.checksum.swapped_endianness().val() } else { - tcp.checksum.val() == self.checksum.swapped_endianness().val() + tcp.checksum.val() == self.checksum.swapped().val() } } } diff --git a/kernel/src/network/tcp_session.rs b/kernel/src/network/tcp_session.rs index d208f47..39a94a3 100644 --- a/kernel/src/network/tcp_session.rs +++ b/kernel/src/network/tcp_session.rs @@ -6,7 +6,7 @@ use super::{ errors::NetworkErrors, ip::IPPacket, layer::{HasChecksum, Layer}, - tcp::TCPPacket, + tcp::TCPPacket, processing::get_window_size, }; /// An enum for different TCP session states @@ -51,9 +51,7 @@ pub struct TCPSession { pub sent_data_acked: u32, /// our ack num pub recv_data_amount: u32, - // pub recv_data_acked: u32, implicit how much we have acked -- this value is unknown to user - // todo: implement a use for window size - /// Window size is the window size of the last packet we received + /// Window size is the window size of the last packet we received (this is how we are doing congestion control) pub window_size: u16, /// If the user has sent fin_ack closing has_sent_fin_ack: bool, @@ -78,8 +76,9 @@ impl TCPSession { dest_ip, dest_port, src_port, - // we max our window size because we want data ASAP - window_size: u16::MAX, + // we set our window size to size of processing buffer. + // after that, it is heap-allocated so we can have as much pending data as we need + window_size: 1500 * get_window_size(), has_sent_fin_ack: false, has_recv_ack_to_fin_ack: false, has_sent_ack_to_fin_ack: false, diff --git a/kernel/src/network/test.rs b/kernel/src/network/test.rs new file mode 100644 index 0000000..f743711 --- /dev/null +++ b/kernel/src/network/test.rs @@ -0,0 +1,73 @@ +use alloc::string::String; + +use crate::{serial_println, QemuExitCode, exit_qemu}; + +use super::{arp_table, arp, bytefield, command_register, devices, dhcp, ethernet, ip, layer, netsync, network_query, processing, raw_array, raw_socket, udp}; + +// Normal sync tests (run before starting async processing) +pub fn test_sync() -> Result<(), String> { + arp_table::test()?; + arp::test()?; + bytefield::test()?; + command_register::test()?; + devices::test()?; + dhcp::test()?; + ethernet::test()?; + ip::test()?; + layer::test()?; + netsync::test()?; + raw_array::test()?; + udp::test()?; + Ok(()) +} + +// Async tests (run after IP is found) +pub async fn test_async() { + // Test network query + if let Err(err) = network_query::test().await { + serial_println!("[ERR] {}", err); + exit_qemu(QemuExitCode::Failed); + } + + // Test processing (stage 2) + if let Err(err) = processing::test().await { + serial_println!("[ERR] {}", err); + exit_qemu(QemuExitCode::Failed); + } + + // Test processing (stage 3) + if let Err(err) = raw_socket::test().await { + serial_println!("[ERR] {}", err); + exit_qemu(QemuExitCode::Failed); + } + +} + +#[macro_export] +macro_rules! check { + ($status: expr, $err: expr) => { + if !$status { + serial_println!("\nfile: {}:{} \x1b[91m[failed]\x1b[0m\n", file!(), line!()); + return Err(String::from($err)); + } // else { + // serial_println!("\t{}, file: {}:{} \x1b[92m[ok]\x1b[0m", $err, file!(), line!()); + // } + }; +} + +#[macro_export] +macro_rules! test_ok { + () => { + serial_println!("\x1b[92m[ok]\x1b[0m"); + return Ok(()); + }; +} + +#[macro_export] +macro_rules! mark_as_test { + ($description: expr) => { + let to_print = format_args!("Running {} test {}:{}", $description, file!(), line!()).as_str(); + let as_string = to_print.unwrap_or(""); + serial_print!("{: <80}", as_string); + }; +} \ No newline at end of file diff --git a/kernel/src/network/udp.rs b/kernel/src/network/udp.rs index 383fe2c..8577bf3 100644 --- a/kernel/src/network/udp.rs +++ b/kernel/src/network/udp.rs @@ -1,3 +1,9 @@ +use crate::{ + check, mark_as_test, + network::{ethernet::EthType, ip::Protocol, layer::full_parse}, + serial_print, serial_println, test_ok, +}; + use super::{ bytefield::Bytefield16, ethernet::EthernetPacket, @@ -5,8 +11,8 @@ use super::{ layer::{calculate_checksum_inner, HasChecksum, Layer, LayerType}, }; -use alloc::vec; use alloc::vec::Vec; +use alloc::{string::String, vec}; /// A UDP packet, implements Layer and HasChecksum (8 bytes) #[derive(Debug)] @@ -69,8 +75,11 @@ impl Layer for UDPPacket { { // create an empty packet let mut packet = UDPPacket::new(); - - // Save ip packet and read 14 bytes + // Check bytevec size + if bytevec.len() < Self::packet_size() as usize { + return (packet, 0, LayerType::Err); + } + // Save ip packet and read 8 bytes let mut i = 0; packet.ip = ip_layer; packet.src_port = Bytefield16::read_inc(&bytevec[i..], &mut i); @@ -79,11 +88,15 @@ impl Layer for UDPPacket { packet.checksum = Bytefield16::read_inc(&bytevec[i..], &mut i); // assert 8 bytes assert!(i == 8); - // Match the destionation port to see if its DHCP + // Match the destination port to see if its DHCP let layer_type = match packet.dest_port.val() { // If port 68, send to DHCP layer 68 => LayerType::Dhcp, _ => { + // Check bytevec size + if bytevec.len() < packet.length.val() as usize { + return (packet, 0, LayerType::Err); + } // read remaining bytes and place them into the data buffer for _ in 0..(packet.length.val() - 8) { packet.data.push(bytevec[i]); @@ -125,7 +138,7 @@ impl HasChecksum for UDPPacket { // Starting vars let mut sum: u32 = 0; - // First we do the IP as a pseduo header + // First we do the IP as a pseudo header let ip = &self.ip; sum += (ip.src_ip.data[0] as u32) | (ip.src_ip.data[1] as u32) << 8; sum += (ip.src_ip.data[2] as u32) | (ip.src_ip.data[3] as u32) << 8; @@ -154,29 +167,82 @@ impl HasChecksum for UDPPacket { eth: EthernetPacket::new(), version_hlen: self.ip.version_hlen, type_of_service: self.ip.type_of_service, - total_length: self.ip.total_length.swapped_endianness(), - identification: self.ip.identification.swapped_endianness(), - flags_fragment_offset: self.ip.flags_fragment_offset.swapped_endianness(), + total_length: self.ip.total_length.swapped(), + identification: self.ip.identification.swapped(), + flags_fragment_offset: self.ip.flags_fragment_offset.swapped(), ttl: self.ip.ttl, protocol: self.ip.protocol, checksum: Bytefield16::new(0), - src_ip: self.ip.src_ip.swapped_endianness(), - dest_ip: self.ip.dest_ip.swapped_endianness(), + src_ip: self.ip.src_ip.swapped(), + dest_ip: self.ip.dest_ip.swapped(), }; ip.calculate_checksum(); - if self.ip.checksum.swapped_endianness().val() != ip.checksum.val() { + if self.ip.checksum.swapped().val() != ip.checksum.val() { return false; } let mut udp: UDPPacket = UDPPacket { ip, - src_port: self.src_port.swapped_endianness(), - dest_port: self.dest_port.swapped_endianness(), - length: self.length.swapped_endianness(), + src_port: self.src_port.swapped(), + dest_port: self.dest_port.swapped(), + length: self.length.swapped(), checksum: Bytefield16::new(0), data: self.data.clone(), }; udp.calculate_checksum(); - udp.checksum.val() == self.checksum.swapped_endianness().val() + udp.checksum.val() == self.checksum.swapped().val() } } + +pub fn test() -> Result<(), String> { + mark_as_test!("UDP Packet"); + + // Create an udp packet, check it serializes correctly + let pkt: UDPPacket = UDPPacket::new(); + check!( + pkt.serialize().len() == (EthernetPacket::packet_size() + IPPacket::packet_size() + UDPPacket::packet_size()) as usize, + "Check serialization size" + ); + + // Create another UDP packet + let eth = EthernetPacket::gen(1, 2, EthType::IPv4); + let ip = IPPacket::gen(eth.clone(), 0, Protocol::Udp, 3, 4); + let mut udp = UDPPacket::gen(ip, 5, 6, 7); + udp.data = vec![1, 2, 3, 4, 5, 6, 7]; + let payload_size = udp.data.len() as u16; + check!(udp.src_port.swapped().val() == 5, "Check src port"); + check!(udp.dest_port.swapped().val() == 6, "Check dest port"); + check!( + udp.length.swapped().val() == UDPPacket::packet_size() + payload_size, + "Check length" + ); + + // Serialize and deserialize + let mut serialized = udp.serialize(); + let (count, should_be_udp) = full_parse(&serialized); + check!(should_be_udp.get_type() == LayerType::Udp, "Check it's a UDP packet"); + let udp_pkt = should_be_udp.unwrap_udp(); + check!( + count as u16 == UDPPacket::packet_size() + IPPacket::packet_size() + EthernetPacket::packet_size() + payload_size, + "parse size" + ); + check!(udp.src_port.swapped() == udp_pkt.src_port, "Check field is same"); + check!(udp.dest_port.swapped() == udp_pkt.dest_port, "Check field is same"); + check!(udp.length.swapped() == udp_pkt.length, "Check field is same"); + check!(udp.data == udp_pkt.data, "Check data is same"); + + // Check parse on smaller vector than expected should also return an error + serialized.pop(); + let (_, should_be_err) = full_parse(&serialized); + check!(should_be_err.get_type() == LayerType::Err, "UDP packet has less data than promised"); + + // Create another UDP packet + let eth = EthernetPacket::gen(1, 2, EthType::IPv4); + let ip = IPPacket::gen(eth.clone(), 0, Protocol::Udp, 3, 4); + let udp = UDPPacket::gen(ip, 5, 6, 0); + let mut serialized = udp.serialize(); + serialized.pop(); + let (_, should_be_err) = full_parse(&serialized); + check!(should_be_err.get_type() == LayerType::Err, "UDP packet has less header"); + test_ok!(); +} diff --git a/kernel/src/task/executor.rs b/kernel/src/task/executor.rs index 03a7fbe..499943e 100644 --- a/kernel/src/task/executor.rs +++ b/kernel/src/task/executor.rs @@ -1,5 +1,5 @@ use super::{Task, TaskId}; -use alloc::{collections::BTreeMap, sync::Arc, task::Wake}; +use alloc::{collections::{BTreeMap, BTreeSet}, sync::Arc, task::Wake}; use core::task::{Context, Poll, Waker}; use crossbeam_queue::ArrayQueue; use x86_64::instructions::interrupts::{self, enable_and_hlt}; @@ -11,6 +11,8 @@ pub struct Executor { task_queue: Arc>, /// A tree of wakers to wake tasks waker_cache: BTreeMap, + /// A reap-able list of task ids + reap_list: BTreeSet, } impl Executor { @@ -20,16 +22,18 @@ impl Executor { tasks: BTreeMap::new(), task_queue: Arc::new(ArrayQueue::new(100)), waker_cache: BTreeMap::new(), + reap_list: BTreeSet::new(), } } /// Spawn a new task - pub fn spawn(&mut self, task: Task) { + pub fn spawn(&mut self, task: Task) -> TaskId { let task_id = task.id; if self.tasks.insert(task.id, task).is_some() { panic!("task with same ID already in tasks"); } self.task_queue.push(task_id).expect("queue full"); + task_id } /// Run the executor forever @@ -41,15 +45,14 @@ impl Executor { } /// Wait until all tasks finish - pub fn wait(&mut self) { + pub fn wait(&mut self, task_id: TaskId) { loop { self.run_ready_tasks(); interrupts::disable(); - if self.task_queue.is_empty() { + if self.reap_list.remove(&task_id) { return; - } else { - interrupts::enable(); } + interrupts::enable(); } } @@ -70,12 +73,17 @@ impl Executor { tasks, task_queue, waker_cache, + reap_list, } = self; while let Some(task_id) = task_queue.pop() { let task = match tasks.get_mut(&task_id) { Some(task) => task, - None => continue, // task no longer exists + None => { + // task no longer exists. Push into reap list + reap_list.insert(task_id); + continue; + }, }; let waker = waker_cache .entry(task_id) @@ -86,6 +94,8 @@ impl Executor { // task done -> remove it and its cached waker tasks.remove(&task_id); waker_cache.remove(&task_id); + // and push its id into the reap list + reap_list.insert(task_id); } Poll::Pending => {} } diff --git a/kernel/src/task/mod.rs b/kernel/src/task/mod.rs index ea13200..c599934 100644 --- a/kernel/src/task/mod.rs +++ b/kernel/src/task/mod.rs @@ -11,7 +11,7 @@ pub mod timeout; pub mod udp_echo; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] -struct TaskId(u64); +pub struct TaskId(u64); pub struct Task { id: TaskId, diff --git a/kernel/src/task/udp_echo.rs b/kernel/src/task/udp_echo.rs index 64ea025..b13cd63 100644 --- a/kernel/src/task/udp_echo.rs +++ b/kernel/src/task/udp_echo.rs @@ -42,7 +42,7 @@ pub async fn udp_echo_server() { if message.to_lowercase().trim() == "quit" || message.to_lowercase().trim() == "exit" { // If exit or quit, we close the socket println!("Closing socket"); - let _ = socket.write(&mut ("Closing socket...".as_bytes().to_vec())).await; + let _ = socket.write(&mut ("Closing socket...\n".as_bytes().to_vec())).await; socket.close().await; return; } diff --git a/src/main.rs b/src/main.rs index 41a4356..d0572f2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -21,7 +21,10 @@ fn main() { cmd.arg("-device").arg("rtl8139,netdev=net0,mac=00:11:22:33:44:55"); // Recording the network traffic in order to conduct debugging and analysis cmd.arg("-object").arg("filter-dump,id=f1,netdev=net0,file=/tmp/dump.pcap"); - + // Turn on debug output + cmd.arg("-serial").arg("stdio"); + cmd.arg("-device").arg("isa-debug-exit,iobase=0xf4,iosize=0x04"); + if uefi { cmd.arg("-bios").arg(ovmf_prebuilt::ovmf_pure_efi()); cmd.arg("-drive").arg(format!("format=raw,file={uefi_path}")); From 772832d444530ad52912e6aa666104d2098bbc25 Mon Sep 17 00:00:00 2001 From: EthanLavi <89608076+EthanLavi@users.noreply.github.com> Date: Sat, 9 Dec 2023 17:30:22 -0500 Subject: [PATCH 29/36] Stable sockets --- kernel/src/network/README.md | 2 +- kernel/src/network/dhcp.rs | 2 +- kernel/src/network/errors.rs | 2 +- kernel/src/network/init.rs | 7 +- kernel/src/network/ip.rs | 9 +- kernel/src/network/layer.rs | 9 +- kernel/src/network/netsync.rs | 4 +- kernel/src/network/network_query.rs | 8 +- kernel/src/network/processing.rs | 49 +++-- kernel/src/network/raw_socket.rs | 5 +- kernel/src/network/rtl8139.rs | 56 +++++- kernel/src/network/socket.rs | 211 +++++++++++++++----- kernel/src/network/tcp.rs | 160 +++++++++++++++- kernel/src/network/tcp_session.rs | 288 ++++++++++++++++++++++++++-- kernel/src/network/test.rs | 37 ++-- kernel/src/network/udp.rs | 13 +- kernel/src/task/tcp_echo.rs | 2 +- kernel/src/task/timeout.rs | 65 +++---- 18 files changed, 766 insertions(+), 163 deletions(-) diff --git a/kernel/src/network/README.md b/kernel/src/network/README.md index ee55e55..f54c289 100644 --- a/kernel/src/network/README.md +++ b/kernel/src/network/README.md @@ -6,7 +6,7 @@ In mod.rs, there are only 4 public facing files. These are the network driver, i Firstly, we initialize the card. This is done by scanning for PCI devices with devices.rs. In this file, we read from the PCI configuration space to get information about the device. Some things that are important are the irq number which is necessary to register interrupts or the io_base which helps us connect to the device over ports. These values are mainly used in rtl8139.rs. This is our device driver. -The device driver serves two purposes (receive data from the card and send data to the card). It abstracts away any packet processing to other files (discussed later). It also provides an API for sending a byte vector to the card. In netsync.rs, we have a SafeRTL8139 object and a network interrupts guard. This handles synchornization. Since interrupts can come at any time, we need to properly block interrupts so we don't deadlock with our synchronization mechanisms. Acquiring the network interrupts guard handles a short critical section for the driver data. In addition our SafeRTL object wraps the locking mechanism so we can acquire/release both the driver data lock and the network interrupts lock in the correct order. +The device driver serves two purposes (receive data from the card and send data to the card). It abstracts away any packet processing to other files (discussed later). It also provides an API for sending a byte vector to the card. In netsync.rs, we have a SafeRTL8139 object and a network interrupts guard. This handles synchronization. Since interrupts can come at any time, we need to properly block interrupts so we don't deadlock with our synchronization mechanisms. Acquiring the network interrupts guard handles a short critical section for the driver data. In addition our SafeRTL object wraps the locking mechanism so we can acquire/release both the driver data lock and the network interrupts lock in the correct order. ## Layer & HasChecksum diff --git a/kernel/src/network/dhcp.rs b/kernel/src/network/dhcp.rs index 663aa09..4a850a0 100644 --- a/kernel/src/network/dhcp.rs +++ b/kernel/src/network/dhcp.rs @@ -229,7 +229,7 @@ pub fn test() -> Result<(), String> { // Create a dhcp packet use crate::network::ip::Protocol; let eth = EthernetPacket::gen(1, 2, EthType::IPv4); - let ip_size = pkt.serialize().len() as u16 - EthernetPacket::packet_size(); + let ip_size = pkt.serialize().len() as u16 - EthernetPacket::packet_size() - IPPacket::packet_size(); let ip = IPPacket::gen(eth, ip_size, Protocol::Udp, 3, 4); let udp = UDPPacket::gen(ip, DHCP_SERVER_PORT, DHCP_CLIENT_PORT, DHCPPacket::packet_size()); let dhcp = DHCPPacket::gen(udp, None, 0x010203040506); diff --git a/kernel/src/network/errors.rs b/kernel/src/network/errors.rs index adff997..8021b99 100644 --- a/kernel/src/network/errors.rs +++ b/kernel/src/network/errors.rs @@ -5,7 +5,7 @@ pub enum NetworkErrors { PortInUse, /// No ports are open NoAvailablePort, - /// The destination is not reachable (can't resolve with ARP) + /// The destination is not reachable (can't resolve with ARP or won't get a response) NonexistentHost, /// The socket is in a bad socket state and we cannot perform the operation for you BadSocketState, diff --git a/kernel/src/network/init.rs b/kernel/src/network/init.rs index 36ada72..f184530 100644 --- a/kernel/src/network/init.rs +++ b/kernel/src/network/init.rs @@ -39,7 +39,7 @@ pub async fn init_dhcp(wait_timeout: u8) -> bool { let packet_data = dhcp.serialize(); // Send the first BOOTP packet - rtl_dev_info.send_packet(&packet_data); + rtl_dev_info.send_packet(&packet_data, BROADCAST_ADDR); // Release the driver object drop(rtl_dev_guard); @@ -61,7 +61,7 @@ pub async fn init_dhcp(wait_timeout: u8) -> bool { let rtl_dev_guard = NET_INFO.lock(); let rtl_dev_info = rtl_dev_guard.get_ref().unwrap(); // send another packet - rtl_dev_info.send_packet(&packet_data); + rtl_dev_info.send_packet(&packet_data, BROADCAST_ADDR); } // Don't try forever retries += 1; @@ -75,11 +75,12 @@ pub async fn init_dhcp(wait_timeout: u8) -> bool { // Get the network stack object let mut rtl_dev_config = NET_INFO.config.lock(); - if pkt_data.get_type() == LayerType::Dhcp { + if pkt_data.get_type() == LayerType::Dhcp { // With the driver object, record the information from the DHCP server let dhcp_res = pkt_data.unwrap_dhcp(); rtl_dev_config.my_ip_address = Some(dhcp_res.my_ip.val()); rtl_dev_config.dhcp_server_ip = Some(dhcp_res.server_ip.val()); + NET_INFO.lock().get_mut().unwrap().ip_address = Some(dhcp_res.my_ip.val()); // Debug print my IP let ip = dhcp_res.my_ip.swapped(); println!("[INFO] IP-Address Assigned As {}.{}.{}.{}", ip[0], ip[1], ip[2], ip[3]); diff --git a/kernel/src/network/ip.rs b/kernel/src/network/ip.rs index 4952451..2203953 100644 --- a/kernel/src/network/ip.rs +++ b/kernel/src/network/ip.rs @@ -36,7 +36,7 @@ impl Protocol { } /// A IP packet, implements Layer and HasChecksum (20 bytes) -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct IPPacket { /// The parent packet pub eth: EthernetPacket, @@ -270,5 +270,12 @@ pub fn test() -> Result<(), String> { ip_res2.get_type() == LayerType::Err, "Deserializing an ip packet without enough data is an error" ); + + // Check parse with wrong ip header length + let eth = EthernetPacket::gen(1, 2, EthType::IPv4); + let mut ip = IPPacket::gen(eth, 0, Protocol::Udp, 3, 4); + ip.total_length = Bytefield16::new(19); // -1 than normal + let (_, ip_res3) = full_parse(&ip.serialize()); + check!(ip_res3.get_type() == LayerType::Err, "IP packet with too small length is an error"); test_ok!(); } diff --git a/kernel/src/network/layer.rs b/kernel/src/network/layer.rs index 72cec01..629b7ec 100644 --- a/kernel/src/network/layer.rs +++ b/kernel/src/network/layer.rs @@ -313,15 +313,16 @@ pub fn test() -> Result<(), String> { (tcp.serialize(), LayerType::Tcp), ]; - for pkt in packets.iter() { - let mut data = pkt.0.clone(); + for (pkt_data, type_of) in packets.iter() { + let mut data = pkt_data.clone(); let (_, undo_pkt1) = full_parse(&data); - check!(undo_pkt1.get_type() == pkt.1, "Checking type of deserialized packet"); + check!(undo_pkt1.get_type() == *type_of, "Checking type of deserialized packet"); // 1 above in amount data.push(0); let (_, err_pkt1) = full_parse(&data); - check!(err_pkt1.get_type() == pkt.1, "Checking type of deserialized packet"); + // ignoring arp for this, because we don't have a strict length limit + check!(err_pkt1.get_type() == *type_of, "Checking type of deserialized packet"); // 1 under in amount data.pop(); diff --git a/kernel/src/network/netsync.rs b/kernel/src/network/netsync.rs index 0a42231..79dd92e 100644 --- a/kernel/src/network/netsync.rs +++ b/kernel/src/network/netsync.rs @@ -89,7 +89,9 @@ impl InterruptCounter { } /// Decrement the value pub fn dec(&mut self) { - self.data -= 1; + if self.data > 0 { // safe decrement + self.data -= 1; + } } } diff --git a/kernel/src/network/network_query.rs b/kernel/src/network/network_query.rs index 57c424f..6818f9b 100644 --- a/kernel/src/network/network_query.rs +++ b/kernel/src/network/network_query.rs @@ -88,7 +88,9 @@ impl NetworkQuery { /// wait_timeout is how many seconds until giving up pub async fn get_mac_from_ip(wait_timeout: u32, dst_ip: u32) -> Option { let mut rtl_dev_config = NET_INFO.config.lock(); - + if dst_ip == NetworkQuery::parse_ip("127.0.0.1").unwrap() { + return rtl_dev_config.my_mac_address; + } // iterate through entries in the arp table for (index, entry) in rtl_dev_config.arp_table.iter().enumerate() { // if entry matches, we can return from the cache @@ -112,7 +114,7 @@ impl NetworkQuery { // Otherwise, we send an arp packet and query the IP let eth_layer = EthernetPacket::gen(BROADCAST_MAC, rtl_dev_info.mac_address.unwrap(), EthType::Arp); let arp_layer = ArpPacket::gen(eth_layer, my_ip, dst_ip, true); - rtl_dev_info.send_packet(&arp_layer.serialize()); + rtl_dev_info.send_packet(&arp_layer.serialize(), dst_ip); // and release the driver drop(rtl_dev_info_locked); @@ -144,7 +146,7 @@ impl NetworkQuery { let rtl_dev_guard = NET_INFO.lock(); let rtl_dev_info = rtl_dev_guard.get_ref().unwrap(); // send the packet - rtl_dev_info.send_packet(&arp_layer.serialize()); + rtl_dev_info.send_packet(&arp_layer.serialize(), dst_ip); } // Count retries and if we exceed the limit, we die retries += 1; diff --git a/kernel/src/network/processing.rs b/kernel/src/network/processing.rs index 4e5e618..c60d250 100644 --- a/kernel/src/network/processing.rs +++ b/kernel/src/network/processing.rs @@ -108,7 +108,8 @@ pub(crate) fn add_pkt_data(data: PacketBuf) { /// (so not very useful beyond testing our network stack) /// If you IPC, don't use this pub(crate) fn local_send_pkt(data: &[u8]) { - // Try to get the array queue + // todo: what if queue is full? + // Try to get the array queue if let Ok(queue) = PENDING_DATA.try_get() { // Simple transfer into an array let mut good: [u8; 1500] = [0; 1500]; @@ -135,17 +136,14 @@ pub async fn init_packet_processing() { // Get the next piece of data the interrupt handler drew from the card while let Some(pkt_data) = raw_packets.next().await { // Parse it - let amount_parsed_and_pkt = full_parse(&pkt_data.buffer); + let amount_parsed_and_pkt = full_parse(&pkt_data.buffer[..pkt_data.length as usize]); if amount_parsed_and_pkt.1.get_type() == LayerType::Err || amount_parsed_and_pkt.1.get_type() == LayerType::Icmp { // Don't deal with unrecognized packets (will fail assert) continue; } if amount_parsed_and_pkt.0 != pkt_data.length as usize && (pkt_data.length as usize) > 64 { // Assert we had a proper amount of data processed without erroring out - println!( - "[ERR] amount parsed: {} --> true length: {}", - amount_parsed_and_pkt.0, pkt_data.length - ); + println!("[ERR] amount parsed: {}, true: {}", amount_parsed_and_pkt.0, pkt_data.length); assert!(amount_parsed_and_pkt.0 == pkt_data.length as usize || (pkt_data.length as usize) <= 64); } @@ -176,7 +174,7 @@ pub async fn init_packet_processing() { let arp_layer = ArpPacket::gen(eth_layer, ip_address, arp.src_ip.val(), false); let arp_pkt = arp_layer.serialize(); // and send it - NET_INFO.lock().get_mut().unwrap().send_packet(&arp_pkt); + NET_INFO.lock().get_mut().unwrap().send_packet(&arp_pkt, arp.src_ip.val()); } } PacketData::DHCP(mut dhcp) => { @@ -255,38 +253,39 @@ pub async fn init_packet_processing() { let tcp_layer = TCPPacket::gen(ip_layer, tcp.dest_port.val(), tcp.src_port.val()); // Create a session with the template let session = TCPSession::new(tcp_layer, tcp.ip.src_ip.val(), tcp.src_port.val(), tcp.dest_port.val()); - // Lets upstream our received packet to the listening port (and wake it) - // - This is used to split the socket into an "Established" session socket and a "Listening" socket - rtl_dev_config - .ports - .get_mut(&dest_port) - .unwrap() - .push_back(Ok(PacketData::TCP(tcp.clone()))); - wake_sockets(dest_port); // Finally insert our tcp session rtl_dev_config.tcp_sessions.insert(session.session_key(), session); - } + } // Get the tcp session let tcp_session = rtl_dev_config.tcp_sessions.get_mut(&session_key).unwrap(); // Generate an acknowledgement via the session's process_recv function - let ack_pkt = tcp_session.process_recv(&tcp); - if let Some(response) = ack_pkt.0 { + let (ack_pkt, recv_action) = tcp_session.process_recv(&tcp); + if let Some(response) = ack_pkt { // If we got a response packet to send back, send it // If no ack is received, the host will send another transmission for us to respond to - NET_INFO.lock().get_mut().unwrap().send_packet(&response.serialize()); + NET_INFO.lock().get_mut().unwrap().send_packet(&response.serialize(), response.ip.dest_ip.swapped().val()); } - // todo: Release TCP resources after 5 minutes if never closed? + // todo: Release TCP-session resources after 5 minutes if socket never closed? // Interpret the action from the process_recv function - if ack_pkt.1 != SessionAction::Drop { + if recv_action == SessionAction::EstablishedSession { + // Upstream our received packet to the listening port (and wake it) + // This is used to split the socket into an "Established" session socket and a "Listening" socket + rtl_dev_config + .ports + .get_mut(&dest_port) + .unwrap() + .push_back(Ok(PacketData::TCP(tcp.clone()))); + wake_sockets(dest_port); + } else if recv_action != SessionAction::Drop { // if we are listening on the session, try to init the packet queue on that end if !rtl_dev_config.ports.contains_key(&session_key) { rtl_dev_config.ports.insert(session_key, VecDeque::new()); } - let res = if ack_pkt.1 == SessionAction::PushUpstream { + let res = if recv_action == SessionAction::PushUpstream { // If the action is upstream, then we push the packet to the raw socket to handle it's data (if present) Ok(PacketData::TCP(tcp)) - } else if ack_pkt.1 == SessionAction::EndOfStream { + } else if recv_action == SessionAction::EndOfStream { // If the action is end of stream, we push the sentinel to indicate the stream has finished Err(NetworkErrors::ClosedSocket) } else { @@ -306,7 +305,7 @@ pub async fn init_packet_processing() { pub async fn test() -> Result<(), String> { - mark_as_test!("Processing (pipeline stage 2)"); + mark_as_test!("Processing (stage 2)"); let eth = EthernetPacket::gen(1, 2, EthType::IPv4); let ip = IPPacket::gen(eth, UDPPacket::packet_size() + 10, Protocol::Udp, 3, 4); @@ -328,11 +327,9 @@ pub async fn test() -> Result<(), String> { local_send_pkt(&good); } local_send_pkt(&bad); - serial_print!("-"); // Process the packets until we run out let mut i = 0; while i != 19 { - serial_print!("."); if let Some(pkt_or_err) = socket.next().await { // Check timeouts if let Err(err) = pkt_or_err { diff --git a/kernel/src/network/raw_socket.rs b/kernel/src/network/raw_socket.rs index 8ff1733..223db3b 100644 --- a/kernel/src/network/raw_socket.rs +++ b/kernel/src/network/raw_socket.rs @@ -23,9 +23,10 @@ lazy_static! { } /// A raw socket object to poll for packets on the network stack +#[derive(Debug)] pub struct RawSocket { /// The port that raw socket owns - port: u64, + pub port: u64, /// How many epochs (of the timer interrupt) until we spontaneously reawaken our processing timeout_in_epochs: u16, /// If the timeout is registered on this raw socket @@ -131,7 +132,7 @@ pub(crate) fn wake_sockets(port: u64) { } pub async fn test() -> Result<(), String> { - mark_as_test!("RawSocket (pipeline stage 3)"); + mark_as_test!("RawSocket (stage 3)"); let rtl_config_guard = NET_INFO.config.lock(); let my_ip = rtl_config_guard.my_ip_address.unwrap_or(0); let my_mac = rtl_config_guard.my_mac_address.unwrap_or(0); diff --git a/kernel/src/network/rtl8139.rs b/kernel/src/network/rtl8139.rs index 20d2cea..8c9fe27 100644 --- a/kernel/src/network/rtl8139.rs +++ b/kernel/src/network/rtl8139.rs @@ -1,5 +1,6 @@ use core::cmp::max; +use alloc::string::String; use alloc::vec::Vec; use alloc::{collections::VecDeque, vec}; use hashbrown::{HashMap, HashSet}; @@ -21,11 +22,12 @@ use crate::network::constants::{ RX_PHYSICAL_MATCH, RX_PROMISCUOUS, RX_READ_PTR_MASK, RX_START_REG, }; use crate::network::raw_array::WrappingRawArray; -use crate::serial_println; +use crate::{serial_print, serial_println, test_ok, mark_as_test, check}; use super::constants::{IMR_REG, INTERRUPT_MASK, ISR_REG, ROK, RTL_DEV, RTL_VEND, TOK, TRANSMIT_CMD, TRANSMIT_REG}; use super::errors::NetworkErrors; -use super::processing::{add_pkt_data, PacketBuf}; +use super::network_query::NetworkQuery; +use super::processing::{add_pkt_data, PacketBuf, local_send_pkt}; use super::tcp_session::TCPSession; use super::{ arp_table::ArpEntry, @@ -189,13 +191,13 @@ fn recv_packet(rtl_dev_info: &RTL8139) { } } -// TODO: Split the driver into separate bits so we can lock individual resources? pub struct RTL8139 { pub device_info: Device, recv_buffer: Option, // 12KB send_buffer: Option, // 12KB physical_mem_offset: Option, pub mac_address: Option, + pub ip_address: Option, } pub struct NetworkConfig { @@ -282,6 +284,7 @@ impl RTL8139 { recv_buffer: None, send_buffer: None, mac_address: None, + ip_address: None, physical_mem_offset: None, }); } @@ -386,7 +389,16 @@ impl RTL8139 { // ! ^^ /// Write the packet data to the card and notify /// This will send the packet - pub fn send_packet(&self, packet_data: &Vec) { + /// - dest_ip only used to determine which interface to send on (local or network) + /// Will return true if send on the network, false if local + pub fn send_packet(&self, packet_data: &Vec, dest_ip: u32) -> bool { + let local = NetworkQuery::parse_ip("127.0.0.1").unwrap(); + let my_ip = self.ip_address.unwrap_or(local); + if dest_ip == local || dest_ip == my_ip { + // If we have a packet that is localhost, send it on the local interface instead + local_send_pkt(packet_data); + return false; + } // If our buffers are invalid, we panic if self.send_buffer.is_none() || self.physical_mem_offset.is_none() { panic!("RTL8139 is not initialized properly"); @@ -394,7 +406,7 @@ impl RTL8139 { // If we don't have a mac address, stop let io_base = self.device_info.io_base; if self.mac_address.is_none() || io_base.is_none() { - return; + panic!("No mac address or io base"); } // Create a virtual address in which we can write the contents of the packet let virtual_buffer: VirtAddr = VirtAddr::new(self.send_buffer.unwrap().as_u64() + self.physical_mem_offset.unwrap()); @@ -426,8 +438,42 @@ impl RTL8139 { TRANSMIT_IDX += 1; TRANSMIT_IDX %= 4; }; + true } } // Note: to actually test the network driver, we need to put packets on the network // For this, I wrote a python script instead of an in-OS test. +// We are going to test the interrupt handling +pub fn test() -> Result<(), String> { + mark_as_test!("RTL Interrupts (stage 1)"); + // Check interrupt disabling behavior + check!(are_network_interrupts_enabled(), "Network interrupts are enabled by default"); + enable_network_interrupts(); // 0 + check!(are_network_interrupts_enabled(), "Network interrupts are enabled by default, should have no effect"); + disable_network_interrupts(); // 1 + check!(!are_network_interrupts_enabled(), "Network interrupts are enabled by default, should have no effect"); + enable_network_interrupts(); // 0 + check!(are_network_interrupts_enabled(), "Network interrupts are enabled by default, should have no effect"); + disable_network_interrupts(); // 1 + disable_network_interrupts(); // 2 + check!(!are_network_interrupts_enabled(), "Network interrupts are enabled by default, should have no effect"); + enable_network_interrupts(); // 1 + check!(!are_network_interrupts_enabled(), "Network interrupts are enabled by default, should have no effect"); + enable_network_interrupts(); // 0 + check!(are_network_interrupts_enabled(), "Network interrupts are enabled by default, should have no effect"); + disable_network_interrupts(); // 1 + disable_network_interrupts(); // 2 + disable_network_interrupts(); // 3 + disable_network_interrupts(); // 4 + disable_network_interrupts(); // 5 + enable_network_interrupts(); // 4 + enable_network_interrupts(); // 3 + enable_network_interrupts(); // 2 + enable_network_interrupts(); // 1 + check!(!are_network_interrupts_enabled(), "Network interrupts are enabled by default, should have no effect"); + enable_network_interrupts(); // 0 + check!(are_network_interrupts_enabled(), "Network interrupts are enabled by default, should have no effect"); + + test_ok!(); +} \ No newline at end of file diff --git a/kernel/src/network/socket.rs b/kernel/src/network/socket.rs index 2dbbf71..2d95f24 100644 --- a/kernel/src/network/socket.rs +++ b/kernel/src/network/socket.rs @@ -8,6 +8,11 @@ use futures_util::StreamExt; use crate::{test_ok, mark_as_test, check, serial_print, serial_println}; use crate::{network::layer::LayerType, println}; +use super::bytefield::Bytefield32; +use super::constants::TCP_SYN; +use super::ethernet::EthType; +use super::tcp::TCPPacket; + use super::{ constants::TCP_PSH, errors::NetworkErrors, @@ -42,6 +47,7 @@ pub struct Socket { socket_type: SocketType, socket_state: SocketState, raw_socket: RawSocket, + backup_socket: Option, dest_port: u16, dest_ip: u32, dest_mac: u64, @@ -96,6 +102,7 @@ impl Socket { Ok(socket) => Ok(Socket { socket_type, raw_socket: socket, + backup_socket: None, socket_state: SocketState::Listening, dest_port: 0, dest_ip: 0, @@ -158,10 +165,10 @@ impl Socket { // Then we unwrap the TCP packet let tcp_pkt = pkt.unwrap_tcp(); // And extract and save the dest_address and dest_port - let dest_address = tcp_pkt.ip.src_ip.val(); + let dest_ip = tcp_pkt.ip.src_ip.val(); let dest_port = tcp_pkt.src_port.val(); // We create a session key - let session_key = TCPSession::gen_session_key(dest_address, dest_port, self.src_port); + let session_key = TCPSession::gen_session_key(dest_ip, dest_port, self.src_port); // And a raw socket let raw_socket = RawSocket::new(session_key, max(self.wait_timeout * 18, 1)).unwrap(); println!("[INFO] Spawned new TCP session"); @@ -170,9 +177,10 @@ impl Socket { return Ok(Some(Socket { socket_type: SocketType::TCP, raw_socket, + backup_socket: None, socket_state: SocketState::Ready, dest_port, - dest_ip: dest_address, + dest_ip, dest_mac: tcp_pkt.ip.eth.src_mac.val(), src_port: self.src_port, src_address: self.src_address, @@ -186,14 +194,8 @@ impl Socket { } /// Connect to a foreign socket that is listening - // todo: untested - pub async fn connect( - socket_type: SocketType, - dest_address: u32, - dest_port: u16, - src_port: u16, - wait_timeout: u16, - ) -> Result { + #[allow(clippy::question_mark)] + pub async fn connect(socket_type: SocketType, dest_ip: u32, dest_port: u16, src_port: u16, wait_timeout: u16) -> Result { // Get the chosen port let mut chosen_src_port = src_port; @@ -202,7 +204,7 @@ impl Socket { // Extract src mac and src address let src_mac = rtl_dev_config.my_mac_address.unwrap(); - let src_address = rtl_dev_config.my_ip_address.unwrap(); + let src_ip = rtl_dev_config.my_ip_address.unwrap(); // if src port is 0 if src_port == 0 { @@ -221,7 +223,7 @@ impl Socket { drop(rtl_dev_config); // Also query for a destination mac address - let dest_mac = NetworkQuery::get_mac_from_ip(10, dest_address).await; + let dest_mac = NetworkQuery::get_mac_from_ip(10, dest_ip).await; // If destination mac is none, we have a non-existent host error if dest_mac.is_none() { @@ -232,24 +234,81 @@ impl Socket { return Err(NetworkErrors::NoAvailablePort); } // Create a raw socket to work with - let raw_socket = RawSocket::new(chosen_src_port as u64, max(wait_timeout * 18, 1)); - // If the raw socket was created successfully, we return a new socket object - match raw_socket { - Ok(socket) => Ok(Socket { - socket_type, - socket_state: SocketState::Ready, - raw_socket: socket, - dest_port, - dest_ip: dest_address, - dest_mac: dest_mac.unwrap(), - src_port, - src_address, - src_mac, - wait_timeout, - session_key: 0, // todo: Connect should generate a tcp session... - }), - Err(err) => Err(err), + let raw_socket_or_err = RawSocket::new(chosen_src_port as u64, max(wait_timeout * 18, 1)); + if let Err(err) = raw_socket_or_err { + return Err(err); + } + let mut session_socket = raw_socket_or_err.unwrap(); + let mut port_socket = None; + + // Setup tcp session + if socket_type == SocketType::TCP { + // Create the session key + let session_key = TCPSession::gen_session_key(dest_ip, dest_port, src_port); + // Setup the sockets + port_socket = Some(session_socket); + let raw_socket_or_err = RawSocket::new(session_key, max(wait_timeout * 18, 1)); + if let Err(err) = raw_socket_or_err { + port_socket.unwrap().close(); + return Err(err); + } + session_socket = raw_socket_or_err.unwrap(); + // Get network stack object + let mut rtl_dev_config = NET_INFO.config.lock(); + // Guard if session exists + if rtl_dev_config.tcp_sessions.contains_key(&session_key) { + return Err(NetworkErrors::PortInUse); + } + // Compact a template packet + let eth_layer = EthernetPacket::gen(dest_mac.unwrap(), src_mac, EthType::IPv4); + // IMP: leaving size undefined for the template (since data size will change) + let ip_layer = IPPacket::gen(eth_layer, TCPPacket::packet_size(), Protocol::Tcp, src_ip, dest_ip); + let mut tcp_layer = TCPPacket::gen(ip_layer, src_port, dest_port); + // Create a session with the template + let session = TCPSession::new(tcp_layer.clone(), dest_ip, dest_port, src_port); + let session_key = session.session_key(); + // Calculate checksums and fill in the seq/ack nums + tcp_layer.turn_on_flags(TCP_SYN); + tcp_layer.seq_num = Bytefield32::new(session.sent_data_amount); + tcp_layer.ack_num = Bytefield32::new(session.recv_data_amount); + tcp_layer.ip.calculate_checksum(); + tcp_layer.calculate_checksum(); + + // Finally insert our tcp session into the network stack object + rtl_dev_config.tcp_sessions.insert(session_key, session); + drop(rtl_dev_config); + + // Give 5 tries to make the connection + let mut found = false; + for _ in 0..5 { + NET_INFO.lock().get_ref().unwrap().send_packet(&tcp_layer.serialize(), dest_ip); + // Loop on packets from the raw socket, until we get a packet we must still sync + if let Some(Ok(_)) = port_socket.as_mut().unwrap().next().await { + found = true; + break; + } + } + // We ran out of retries, exiting with an error + if !found { + return Err(NetworkErrors::NonexistentHost); + } + } + // If the raw socket was created successfully, we return a new socket object + Ok(Socket { + socket_type, + socket_state: SocketState::Ready, + raw_socket: session_socket, + backup_socket: port_socket, + dest_port, + dest_ip, + dest_mac: dest_mac.unwrap(), + src_port, + src_address: src_ip, + src_mac, + wait_timeout, + session_key: TCPSession::gen_session_key(dest_ip, dest_port, src_port), + }) } /// Close the socket and release the resources associated with it @@ -282,7 +341,7 @@ impl Socket { // Send the FIN-ACK packet let mut rtl_dev_info_locked = NET_INFO.lock(); let rtl_dev_info = rtl_dev_info_locked.get_mut().unwrap(); - rtl_dev_info.send_packet(&pkt.serialize()); + rtl_dev_info.send_packet(&pkt.serialize(), self.dest_ip); drop(rtl_dev_info_locked); loop { @@ -304,6 +363,10 @@ impl Socket { } // Close the raw socket self.raw_socket.close(); + // Close the backup if it exists + if let Some(backup) = self.backup_socket { + backup.close(); + } } /// Read data from the socket @@ -432,7 +495,7 @@ impl Socket { let packet_data = udp_layer.serialize(); // Send the packet - NET_INFO.lock().get_ref().unwrap().send_packet(&packet_data); + NET_INFO.lock().get_ref().unwrap().send_packet(&packet_data, self.dest_ip); // Return how much was written Ok(data_len as u16) @@ -464,18 +527,14 @@ impl Socket { return Err(err); } // Unwrap the message - let message = message_pkt.unwrap().serialize(); + let tmp = message_pkt.unwrap(); + let message = tmp.serialize(); drop(tcp_session_guard); - // Wait for the ack -- 20 retries - for retries in 1..21 { - // Acquire the driver - let mut rtl_dev_info_locked = NET_INFO.lock(); - let rtl_dev_info = rtl_dev_info_locked.get_mut().unwrap(); + // Wait for the ack -- 5 retries + for retries in 1..6 { // Send the packet - rtl_dev_info.send_packet(&message); - // Release the driver - drop(rtl_dev_info_locked); + NET_INFO.lock().get_ref().unwrap().send_packet(&message, self.dest_ip); // Get next packet or timeout if let Some(pkt_or_err) = self.raw_socket.next().await { @@ -495,7 +554,7 @@ impl Socket { } } } - if retries == 20 { + if retries == 5 { // We reach too many timeouts so we return an error return Err(NetworkErrors::Timeout); } @@ -505,11 +564,71 @@ impl Socket { } } +pub async fn test() -> Result<(), String> { + mark_as_test!("Socket (stage 4)"); + // UDP socket first + let my_ip = NetworkQuery::parse_ip("127.0.0.1").unwrap(); + let recv_wrapped = Socket::open(SocketType::UDP, 50001, 10).await; + let send_wrapped = Socket::connect(SocketType::UDP, my_ip, 50001, 50000, 1).await; + check!(send_wrapped.is_ok(), "Opening socket is ok"); + check!(recv_wrapped.is_ok(), "Opening second socket is ok"); + let mut send = send_wrapped.unwrap(); + let mut recv = recv_wrapped.unwrap(); + let not_err = send.reliable_write(&mut vec![0, 1]).await; + check!(not_err.is_none(), "Writing is not an error"); + let udp_conn = recv.listen().await; + check!(udp_conn.is_ok() && udp_conn.unwrap().is_none(), "Can make the connection"); + let data = recv.read(2).await; + check!(data.is_ok(), "Data was received ok"); + check!(data.unwrap() == vec![0, 1], "Data is correct"); + + for _ in 0..1000 { + // writing 200 bytes + let not_err = send.reliable_write(&mut vec![0; 200]).await; + check!(not_err.is_none(), "Writing is not an error"); + let data = recv.read(200).await; + check!(data.is_ok(), "Data was received ok"); + } + + send.close().await; + recv.close().await; + + // Open TCP server socket + let recv_wrapped_listener = Socket::open(SocketType::TCP, 60001, 2).await; + check!(recv_wrapped_listener.is_ok(), "Opening server socket is ok"); + let mut server_socket = recv_wrapped_listener.unwrap(); + + // TCP socket, create 20 different connection iterations + for _ in 0..20 { + let send_wrapped = Socket::connect(SocketType::TCP, my_ip, 60001, 60000, 2).await; + check!(send_wrapped.is_ok(), "Opening socket is ok"); + + let recv_wrapped = server_socket.listen().await; + check!(recv_wrapped.is_ok(), "Listening socket is ok"); + let mut send = send_wrapped.unwrap(); + let mut recv = recv_wrapped.unwrap().unwrap(); + + // Send data 10 times + for _ in 0..10 { + // Writing 200 bytes + let not_err = send.reliable_write(&mut vec![0; 200]).await; + if not_err.is_some() { + serial_println!("[ERR] {:?}", not_err); + } + check!(not_err.is_none(), "Writing is not an error"); + let data = recv.read(200).await; + check!(data.is_ok(), "Data was received ok"); + } + send.close().await; + recv.close().await; + } -pub async fn test() -> Result<(), String> { - mark_as_test!("Socket"); - check!(true, "Yay!"); - serial_println!("todo: this test file, using local_send"); + // todo: Errors with sending in bad states (without listening) + // todo: TCP server with multiple clients + // todo: closing with no good state + // todo: sending multiple packets and read in chunks + // todo: what happens if we both write and then read... !! + // todo: data comes completely test_ok!(); } \ No newline at end of file diff --git a/kernel/src/network/tcp.rs b/kernel/src/network/tcp.rs index 5562cf2..b7dd833 100644 --- a/kernel/src/network/tcp.rs +++ b/kernel/src/network/tcp.rs @@ -1,6 +1,12 @@ -use alloc::vec; +use alloc::{vec, string::String}; use alloc::vec::Vec; +use crate::network::constants::{TCP_SYN, TCP_ACK, TCP_FIN, TCP_RST}; +use crate::network::ethernet::EthType; +use crate::network::ip::Protocol; +use crate::network::layer::full_parse; +use crate::{mark_as_test, check, test_ok, serial_print, serial_println}; + use super::{ bytefield::{Bytefield16, Bytefield32}, ethernet::EthernetPacket, @@ -9,7 +15,7 @@ use super::{ }; /// A TCP packet, implements Layer and HasChecksum (20 bytes) -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct TCPPacket { /// The parent packet pub ip: IPPacket, @@ -44,7 +50,7 @@ impl TCPPacket { dest_port: Bytefield16::new(0), seq_num: Bytefield32::new(0), ack_num: Bytefield32::new(0), - flags: Bytefield16::new(0), + flags: Bytefield16::new(5 << 12), sliding_window: Bytefield16::new(0), checksum: Bytefield16::new(0), urgent: Bytefield16::new(0), @@ -82,14 +88,14 @@ impl TCPPacket { 0 } - /// N.B. Getting operations DONT swap endianness because we should be in host order (and after parsing we are) + /// N.B. Getting operations DON'T swap endianness because we should be in host order (and after parsing we are) /// Get the header offset pub fn get_header_offset(&self) -> u8 { // convert header offset into bytes ((self.flags.val() >> 12) as u8) * 4 } - /// N.B. Getting operations DONT swap endianness because we should be in host order (and after parsing we are) + /// N.B. Getting operations DON'T swap endianness because we should be in host order (and after parsing we are) /// Get the flags pub fn get_flags(&self) -> u8 { (self.flags.val() & 0x00FF) as u8 @@ -102,7 +108,29 @@ impl TCPPacket { self.flags = Bytefield16::new(new_flags); } - /// N.B. Getting operations DONT swap endianness because we should be in host order (and after parsing we are) + /// [TESTING FUNCTION]. Will swap order of every field + pub fn swapped(&self) -> Self { + let mut new = self.clone(); + new.ip.eth.dest_mac = new.ip.eth.dest_mac.swapped(); + new.ip.eth.src_mac = new.ip.eth.src_mac.swapped(); + new.ip.dest_ip = new.ip.dest_ip.swapped(); + new.ip.src_ip = new.ip.src_ip.swapped(); + new.ip.checksum = new.ip.checksum.swapped(); + new.ip.total_length = new.ip.total_length.swapped(); + new.ip.identification = new.ip.identification.swapped(); + new.ip.flags_fragment_offset = new.ip.flags_fragment_offset.swapped(); + new.ack_num = new.ack_num.swapped(); + new.seq_num = new.seq_num.swapped(); + new.checksum = new.checksum.swapped(); + new.dest_port = new.dest_port.swapped(); + new.src_port = new.src_port.swapped(); + new.flags = new.flags.swapped(); + new.sliding_window = new.sliding_window.swapped(); + new.urgent = new.urgent.swapped(); + new + } + + /// N.B. Getting operations DON'T swap endianness because we should be in host order (and after parsing we are) /// Get total length of the tcp portion pub fn total_size(&self) -> u16 { (20 + self.options.len() + self.data.len()) as u16 @@ -195,7 +223,7 @@ impl HasChecksum for TCPPacket { // Starting vars let mut sum: u32 = 0; - // First we do the IP as a pseduo header + // First we do the IP as a pseudo header let ip = &self.ip; sum += (ip.src_ip.data[0] as u32) | (ip.src_ip.data[1] as u32) << 8; sum += (ip.src_ip.data[2] as u32) | (ip.src_ip.data[3] as u32) << 8; @@ -278,3 +306,121 @@ impl HasChecksum for TCPPacket { } } } + + +pub fn test() -> Result<(), String> { + mark_as_test!("TCP Packet"); + + // Create a TCP packet, check it serializes correctly + let mut pkt: TCPPacket = TCPPacket::new(); + pkt.flags = pkt.flags.swapped(); // swap flags so that it is in host order + check!( + pkt.serialize().len() == (EthernetPacket::packet_size() + IPPacket::packet_size() + pkt.get_header_offset() as u16) as usize, + "Check serialization size" + ); + + // Create another TCP packet + let data = vec![1, 2, 3, 4, 5, 6, 7]; + let payload_size = data.len() as u16; + let eth = EthernetPacket::gen(1, 2, EthType::IPv4); + let ip = IPPacket::gen(eth.clone(), payload_size + TCPPacket::packet_size(), Protocol::Tcp, 3, 4); + let mut tcp = TCPPacket::gen(ip, 5, 6); + tcp.data = data; + check!(tcp.src_port.swapped().val() == 5, "Check src port"); + check!(tcp.dest_port.swapped().val() == 6, "Check dest port"); + + // Serialize and deserialize + let mut serialized = tcp.serialize(); + let (count, should_be_tcp) = full_parse(&serialized); + check!(should_be_tcp.get_type() == LayerType::Tcp, "Check it's a TCP packet"); + let tcp_pkt = should_be_tcp.unwrap_tcp(); + check!( + count as u16 == tcp_pkt.get_header_offset() as u16 + IPPacket::packet_size() + EthernetPacket::packet_size() + payload_size, + "parse size" + ); + check!(tcp.src_port.swapped() == tcp_pkt.src_port, "Check field is same"); + check!(tcp.dest_port.swapped() == tcp_pkt.dest_port, "Check field is same"); + check!(tcp.ack_num.swapped() == tcp_pkt.ack_num, "Check field is same"); + check!(tcp.seq_num.swapped() == tcp_pkt.seq_num, "Check field is same"); + check!(tcp.flags.swapped() == tcp_pkt.flags, "Check field is same"); + check!(tcp.urgent.swapped() == tcp_pkt.urgent, "Check field is same"); + check!(tcp.sliding_window.swapped() == tcp_pkt.sliding_window, "Check field is same"); + check!(tcp.options == tcp_pkt.options, "Check options is same"); + check!(tcp.data == tcp_pkt.data, "Check data is same"); + + // Check parse on smaller vector than expected should also return an error + serialized.pop(); + let (_, should_be_err) = full_parse(&serialized); + check!(should_be_err.get_type() == LayerType::Err, "TCP packet has less data than promised"); + + // Create another TCP packet with no body, make it smaller than usual + let eth = EthernetPacket::gen(1, 2, EthType::IPv4); + let ip = IPPacket::gen(eth.clone(), 0, Protocol::Tcp, 3, 4); + let mut tcp = TCPPacket::gen(ip, 5, 6); + let mut serialized = tcp.serialize(); + serialized.pop(); + let (_, should_be_err) = full_parse(&serialized); + check!(should_be_err.get_type() == LayerType::Err, "TCP packet has less header"); + + // Check flags + let parsed_tcp = tcp.swapped(); + let has_syn_flag = (parsed_tcp.get_flags() & TCP_SYN) != 0; + let has_ack_flag = (parsed_tcp.get_flags() & TCP_ACK) != 0; + let has_fin_flag = (parsed_tcp.get_flags() & TCP_FIN) != 0; + let has_rst_flag = (parsed_tcp.get_flags() & TCP_RST) != 0; + check!(!has_syn_flag, "syn flag"); + check!(!has_ack_flag, "ack flag"); + check!(!has_fin_flag, "fin flag"); + check!(!has_rst_flag, "rst flag"); + + tcp.turn_on_flags(TCP_SYN); + let parsed_tcp = tcp.swapped(); + + let has_syn_flag = (parsed_tcp.get_flags() & TCP_SYN) != 0; + let has_ack_flag = (parsed_tcp.get_flags() & TCP_ACK) != 0; + let has_fin_flag = (parsed_tcp.get_flags() & TCP_FIN) != 0; + let has_rst_flag = (parsed_tcp.get_flags() & TCP_RST) != 0; + check!(has_syn_flag, "syn flag"); + check!(!has_ack_flag, "ack flag"); + check!(!has_fin_flag, "fin flag"); + check!(!has_rst_flag, "rst flag"); + + tcp.turn_on_flags(TCP_ACK); + let parsed_tcp = tcp.swapped(); + + let has_syn_flag = (parsed_tcp.get_flags() & TCP_SYN) != 0; + let has_ack_flag = (parsed_tcp.get_flags() & TCP_ACK) != 0; + let has_fin_flag = (parsed_tcp.get_flags() & TCP_FIN) != 0; + let has_rst_flag = (parsed_tcp.get_flags() & TCP_RST) != 0; + check!(has_syn_flag, "syn flag"); + check!(has_ack_flag, "ack flag"); + check!(!has_fin_flag, "fin flag"); + check!(!has_rst_flag, "rst flag"); + + tcp.turn_on_flags(TCP_FIN); + let parsed_tcp = tcp.swapped(); + + let has_syn_flag = (parsed_tcp.get_flags() & TCP_SYN) != 0; + let has_ack_flag = (parsed_tcp.get_flags() & TCP_ACK) != 0; + let has_fin_flag = (parsed_tcp.get_flags() & TCP_FIN) != 0; + let has_rst_flag = (parsed_tcp.get_flags() & TCP_RST) != 0; + check!(has_syn_flag, "syn flag"); + check!(has_ack_flag, "ack flag"); + check!(has_fin_flag, "fin flag"); + check!(!has_rst_flag, "rst flag"); + + tcp.turn_on_flags(TCP_RST); + let parsed_tcp = tcp.swapped(); + + let has_syn_flag = (parsed_tcp.get_flags() & TCP_SYN) != 0; + let has_ack_flag = (parsed_tcp.get_flags() & TCP_ACK) != 0; + let has_fin_flag = (parsed_tcp.get_flags() & TCP_FIN) != 0; + let has_rst_flag = (parsed_tcp.get_flags() & TCP_RST) != 0; + check!(has_syn_flag, "syn flag"); + check!(has_ack_flag, "ack flag"); + check!(has_fin_flag, "fin flag"); + check!(has_rst_flag, "rst flag"); + + check!(parsed_tcp.swapped() == tcp, "Swapping back gives us the original packet"); + test_ok!(); +} \ No newline at end of file diff --git a/kernel/src/network/tcp_session.rs b/kernel/src/network/tcp_session.rs index 39a94a3..96296c9 100644 --- a/kernel/src/network/tcp_session.rs +++ b/kernel/src/network/tcp_session.rs @@ -1,4 +1,6 @@ -use crate::{println, crypto::rng::get_next_random_num}; +use alloc::{string::String, vec}; + +use crate::{println, crypto::rng::get_next_random_num, mark_as_test, test_ok, check, serial_print, serial_println, network::{ethernet::{EthernetPacket, EthType}, ip::Protocol}}; use super::{ bytefield::{Bytefield16, Bytefield32}, @@ -27,6 +29,8 @@ pub enum TCPSessionState { /// An enum to be the action we take after processing a packet (receiving end only) #[derive(Debug, PartialEq, Eq)] pub enum SessionAction { + /// Push new session + EstablishedSession, /// Push a packet upstream to a higher level of abstraction (i.e. ports) PushUpstream, /// Drop a packet for lack of usefulness or bad state @@ -55,7 +59,7 @@ pub struct TCPSession { pub window_size: u16, /// If the user has sent fin_ack closing has_sent_fin_ack: bool, - /// If we have receieved an ack to our FIN-ACK message + /// If we have received an ack to our FIN-ACK message has_recv_ack_to_fin_ack: bool, /// If we have sent an ack too a FIN-ACK message has_sent_ack_to_fin_ack: bool, @@ -114,7 +118,6 @@ impl TCPSession { tcp_pkt.turn_on_flags(TCP_FIN | TCP_ACK); // Add the data size - // todo: (maybe make this automatic?) tcp_pkt.ip.total_length = Bytefield16::new(TCPPacket::packet_size() + TCPPacket::options_size() + IPPacket::packet_size()); tcp_pkt.seq_num = Bytefield32::new(self.sent_data_amount); tcp_pkt.ack_num = Bytefield32::new(self.recv_data_amount); @@ -204,6 +207,9 @@ impl TCPSession { // Set our ack_num and seq_num for our response response.ack_num = Bytefield32::new(self.recv_data_amount); response.seq_num = Bytefield32::new(self.sent_data_amount); + // Transition to next state + self.session_state = TCPSessionState::Established; + response_action = SessionAction::EstablishedSession; } else if has_syn_flag { // If we received just a SYN, send back a SYN/ACK response.turn_on_flags(TCP_SYN | TCP_ACK); @@ -212,13 +218,15 @@ impl TCPSession { // Set our responses ack and seq num response.ack_num = Bytefield32::new(self.recv_data_amount); response.seq_num = Bytefield32::new(self.sent_data_amount); + // Transition to next state + self.session_state = TCPSessionState::Syncing; + response_action = SessionAction::Drop; } else { // No syn packet but the session hasn't been established // Therefore we drop the packet return (None, SessionAction::Drop); } - // Transition to next state - self.session_state = TCPSessionState::Syncing; + } TCPSessionState::Syncing => { if has_ack_flag { @@ -233,7 +241,7 @@ impl TCPSession { // Transition to next state self.session_state = TCPSessionState::Established; // Has no need for a response, but push upstream - return (None, SessionAction::PushUpstream); + return (None, SessionAction::EstablishedSession); } else { // Waiting on the ack packet // Dropping this packet @@ -243,16 +251,21 @@ impl TCPSession { TCPSessionState::Established => { // Default state is dropping let mut has_info = SessionAction::Drop; + if request.seq_num.val() != self.recv_data_amount { + // wrong seq num, drop the packet + return (None, SessionAction::Drop); + } + // Then update the ack num, if we need to if request.ack_num.val() > self.sent_data_acked && request.ack_num.val() <= self.sent_data_amount { // We are updating our record of how much data was acked self.sent_data_acked = request.ack_num.val(); // Even with no data, we push our packet upstream because --> - // the ack needs to release the write (since it blocks until its sure the write occured) + // the ack needs to release the write (since it blocks until its sure the write occurred) has_info = SessionAction::PushUpstream; } // Check sequence number for a match and if we have data - if request.seq_num.val() == self.recv_data_amount && !request.data.is_empty() { + if !request.data.is_empty() { // Then we increment our ack_num (what we've received) self.recv_data_amount += request.data.len() as u32; // Set response seq and ack num, and turn on ack flag @@ -260,7 +273,7 @@ impl TCPSession { response.seq_num = Bytefield32::new(self.sent_data_amount); response.turn_on_flags(TCP_ACK); } else { - // No data is present, so we push our packet upstream ONLY if it acked a write + // No data is present, so we push our packet upstream ONLY if it acked a write and has valid seq nums return (None, has_info); } } @@ -287,7 +300,7 @@ impl TCPSession { } // And push our packet upstream to release any closing sockets - return (None, SessionAction::PushUpstream); + return (None, SessionAction::EndOfStream); } // Increment ack_num @@ -305,9 +318,262 @@ impl TCPSession { // Calculate checksums response.ip.calculate_checksum(); response.calculate_checksum(); - // Return result and action (Some(response), response_action) } } + + +pub fn test() -> Result<(), String> { + mark_as_test!("TCP Session"); + // Template TCP packet + let eth = EthernetPacket::gen(1, 2, EthType::IPv4); + let ip = IPPacket::gen(eth, TCPPacket::packet_size(), Protocol::Tcp, 4, 4); + let mut tcp = TCPPacket::gen(ip, 5, 6); + + // ----------------- PART 1 ------------------- // + // Test state machine of tcp session + let mut session1 = TCPSession::new(tcp.clone(), 4, 6, 5); + let mut session2 = TCPSession::new(tcp.clone(), 4, 5, 6); + tcp.seq_num = Bytefield32::new(session1.sent_data_amount); + tcp.ack_num = Bytefield32::new(session1.recv_data_amount); + + // Generate packets for 3-way handshake + check!(session1.session_state == TCPSessionState::Waiting, "Starting at waiting"); + check!(session2.session_state == TCPSessionState::Waiting, "Starting at waiting"); + let (drop_no_syn, action) = session1.process_recv(&tcp.swapped()); + check!(drop_no_syn.is_none(), "No response packet, there is no SYN"); + check!(action == SessionAction::Drop, "Dropping packet that has no SYN"); + let mut syn = tcp.clone(); + syn.turn_on_flags(TCP_SYN); + + let (syn_ack, action) = session2.process_recv(&syn.swapped()); + check!(syn_ack.is_some(), "SYN packet generates syn_ack packet"); + check!(action == SessionAction::Drop, "Sending syn_ack packet"); + check!(session1.session_state == TCPSessionState::Waiting, "Session1 is still waiting"); + check!(session2.session_state == TCPSessionState::Syncing, "Session2 is now syncing"); + + let (ack, action) = session1.process_recv(&syn_ack.unwrap().swapped()); + check!(ack.is_some(), "SYN-ACK packet generates ack packet"); + check!(action == SessionAction::EstablishedSession, "Sending ack packet"); + check!(session1.session_state == TCPSessionState::Established, "Session1 is now established"); + check!(session2.session_state == TCPSessionState::Syncing, "Session2 is syncing"); + + let (done, action) = session2.process_recv(&ack.unwrap().swapped()); + check!(done.is_none(), "ACK packet generates no packet"); + check!(action == SessionAction::EstablishedSession, "Push ack packet upstream"); + check!(session1.session_state == TCPSessionState::Established, "Session1 is established"); + check!(session2.session_state == TCPSessionState::Established, "Session2 is now established"); + + // Sending data + for i in 0..100 { + // Invariants to start + check!(session1.session_state == TCPSessionState::Established, "Session1 is established"); + check!(session2.session_state == TCPSessionState::Established, "Session2 is now established"); + check!(session1.everything_acked(), "Everything acked to start"); + check!(session2.everything_acked(), "Everything acked to start"); + + // Send data on session1 + let data = vec![0; i + 1]; + let to_send = session1.process_send(&data); + check!(to_send.is_ok(), "Process send is ok"); + check!(!session1.everything_acked(), "Everything not acked yet"); + + // Get data on session2 + let (ack_back, action) = session2.process_recv(&to_send.unwrap().swapped()); + check!(ack_back.is_some(), "Process recv has response"); + check!(action == SessionAction::PushUpstream, "Process recv pushes upstream"); + check!(!session1.everything_acked(), "Everything not acked yet"); + + // Get ack on session1 + let (no_back, action) = session1.process_recv(&ack_back.unwrap().swapped()); + check!(no_back.is_none(), "Process recv-ack has no response"); + check!(action == SessionAction::PushUpstream, "Process recv-ack pushes upstream"); + check!(session1.everything_acked(), "Everything acked"); + } + + // Close session (4-way handshake), first fin packet + let fin_pkt = session1.close(); + check!(fin_pkt.is_ok(), "Created a fin packet"); + check!(session1.session_state == TCPSessionState::Closing, "Session1 is now closing"); + check!(session2.session_state == TCPSessionState::Established, "Session2 is still established"); + + // Send fin-ack packets back + let (fin_ack_pkt, action) = session2.process_recv(&fin_pkt.unwrap().swapped()); + check!(fin_ack_pkt.is_some(), "Created a fin-ack packet"); + check!(action == SessionAction::EndOfStream, "Should be end of stream now"); + check!(session1.session_state == TCPSessionState::Closing, "Session1 is closing"); + check!(session2.session_state == TCPSessionState::Closing, "Session2 is now closing"); + + // Receive fin-ack packet + let (no_res, action) = session1.process_recv(&fin_ack_pkt.unwrap().swapped()); + check!(no_res.is_none(), "No more packets"); + check!(action == SessionAction::EndOfStream, "Should be end of stream now"); + check!(session1.session_state == TCPSessionState::Closing, "Session1 is closing"); + check!(session2.session_state == TCPSessionState::Closing, "Session2 is closing"); + + // Can close session2 now + let fin_pkt = session2.close(); + check!(fin_pkt.is_ok(), "Created a fin packet"); + check!(session1.session_state == TCPSessionState::Closing, "Session1 is closing"); + check!(session2.session_state == TCPSessionState::Closing, "Session2 is closing"); + + // Send fin-ack packets back + let (fin_ack_pkt, action) = session1.process_recv(&fin_pkt.unwrap().swapped()); + check!(fin_ack_pkt.is_some(), "Created a fin-ack packet"); + check!(action == SessionAction::EndOfStream, "Should be end of stream"); + check!(session1.session_state == TCPSessionState::Closed, "Session1 is now closed"); + check!(session2.session_state == TCPSessionState::Closing, "Session2 is closing"); + + // Recv fin-ack + let (no_res, action) = session2.process_recv(&fin_ack_pkt.unwrap().swapped()); + check!(no_res.is_none(), "No more packets"); + check!(action == SessionAction::EndOfStream, "Should be end of stream now"); + check!(session1.session_state == TCPSessionState::Closed, "Session1 is closed"); + check!(session2.session_state == TCPSessionState::Closed, "Session2 is now closed"); + + // ----------------- PART 2 ------------------- // + // Test state machine of tcp session + let mut session1 = TCPSession::new(tcp.clone(), 4, 6, 5); + let mut session2 = TCPSession::new(tcp.clone(), 4, 5, 6); + + // Generate packets for 3-way handshake + check!(session1.session_state == TCPSessionState::Waiting, "Starting at waiting"); + check!(session2.session_state == TCPSessionState::Waiting, "Starting at waiting"); + let (drop_no_syn, action) = session1.process_recv(&tcp.swapped()); + check!(drop_no_syn.is_none(), "No response packet, there is no SYN"); + check!(action == SessionAction::Drop, "Dropping packet that has no SYN"); + // set the seq and ack num + tcp.seq_num = Bytefield32::new(session1.sent_data_amount); + tcp.ack_num = Bytefield32::new(session1.recv_data_amount); + let mut syn = tcp.clone(); + syn.turn_on_flags(TCP_SYN); + + let (syn_ack, action) = session2.process_recv(&syn.swapped()); + check!(syn_ack.is_some(), "SYN packet generates syn_ack packet"); + check!(action == SessionAction::Drop, "Sending syn_ack packet"); + check!(session1.session_state == TCPSessionState::Waiting, "Session1 is still waiting"); + check!(session2.session_state == TCPSessionState::Syncing, "Session2 is now syncing"); + + let (ack, action) = session1.process_recv(&syn_ack.clone().unwrap().swapped()); + check!(ack.is_some(), "SYN-ACK packet generates ack packet"); + check!(action == SessionAction::EstablishedSession, "Sending ack packet"); + check!(session1.session_state == TCPSessionState::Established, "Session1 is now established"); + check!(session2.session_state == TCPSessionState::Syncing, "Session2 is syncing"); + + let (done, action) = session2.process_recv(&ack.clone().unwrap().swapped()); + check!(done.is_none(), "ACK packet generates no packet"); + check!(action == SessionAction::EstablishedSession, "Push ack packet upstream"); + check!(session1.session_state == TCPSessionState::Established, "Session1 is established"); + check!(session2.session_state == TCPSessionState::Established, "Session2 is now established"); + + // Sending errors during established + for i in 0..100 { + // Invariants to start + check!(session1.session_state == TCPSessionState::Established, "Session1 is established"); + check!(session2.session_state == TCPSessionState::Established, "Session2 is now established"); + check!(session1.everything_acked(), "Everything acked to start"); + check!(session2.everything_acked(), "Everything acked to start"); + + // Send data on session1 + let data = vec![0; i + 1]; + let to_send_wrapped = session1.process_send(&data); + check!(to_send_wrapped.is_ok(), "Process send is ok"); + let to_send = to_send_wrapped.unwrap(); + check!(!session1.everything_acked(), "Everything not acked yet"); + + // Recv-ing packets with low seq num + let mut wrong_seq = to_send.clone(); + wrong_seq.seq_num = Bytefield32::new(to_send.seq_num.swapped().val() - 1); + let (ack_back, action) = session2.process_recv(&wrong_seq.swapped()); + check!(ack_back.is_none(), "Process recv has no response"); + check!(action == SessionAction::Drop, "Process recv pushes upstream"); + // Recv-ing packets with high seq num + let mut wrong_seq = to_send.clone(); + wrong_seq.seq_num = Bytefield32::new(to_send.seq_num.swapped().val() + 50); + let (ack_back, action) = session2.process_recv(&wrong_seq.swapped()); + check!(ack_back.is_none(), "Process recv has no response"); + check!(action == SessionAction::Drop, "Process recv pushes upstream"); + + // Recv data on session2 (valid) + let (ack_back_wrapped, action) = session2.process_recv(&to_send.swapped()); + check!(ack_back_wrapped.is_some(), "Process recv has response"); + let ack_back = ack_back_wrapped.unwrap(); + check!(action == SessionAction::PushUpstream, "Process recv pushes upstream"); + check!(!session1.everything_acked(), "Everything not acked yet"); + + // Recv-ing packets with low seq num + let mut wrong_seq = ack_back.clone(); + wrong_seq.seq_num = Bytefield32::new(ack_back.seq_num.swapped().val() - 1); + let (to_drop, action) = session1.process_recv(&wrong_seq.swapped()); + check!(to_drop.is_none(), "Process recv has no response"); + check!(action == SessionAction::Drop, "Process recv pushes"); + check!(!session1.everything_acked(), "Everything not acked yet"); + // Recv-ing packets with high seq num + let mut wrong_seq = ack_back.clone(); + wrong_seq.seq_num = Bytefield32::new(ack_back.seq_num.swapped().val() + 50); + let (to_drop, action) = session1.process_recv(&wrong_seq.swapped()); + check!(to_drop.is_none(), "Process recv has no response"); + check!(action == SessionAction::Drop, "Process recv drops"); + check!(!session1.everything_acked(), "Everything not acked yet"); + + // Get ack on session1 + let (no_back, action) = session1.process_recv(&ack_back.swapped()); + check!(no_back.is_none(), "Process recv-ack has no response"); + check!(action == SessionAction::PushUpstream, "Process recv-ack pushes upstream"); + check!(session1.everything_acked(), "Everything acked"); + + // Sending packets with syn flag + let (no_pkt1, action1) = session1.process_recv(&syn_ack.clone().unwrap().swapped()); + let (no_pkt2, action2) = session2.process_recv(&syn.clone().swapped()); + check!(no_pkt1.is_none() && no_pkt2.is_none(), "No packet generated"); + check!(action1 == SessionAction::Drop && action1 == action2, "Drop both packets"); + check!(session1.session_state == TCPSessionState::Established, "Session1 is established"); + check!(session2.session_state == TCPSessionState::Established, "Session2 is now established"); + check!(session1.everything_acked(), "Everything acked"); + check!(session2.everything_acked(), "Everything acked"); + } + + // Close session (4-way handshake), first fin packet + let fin_pkt = session1.close(); + check!(fin_pkt.is_ok(), "Created a fin packet"); + check!(session1.session_state == TCPSessionState::Closing, "Session1 is now closing"); + check!(session2.session_state == TCPSessionState::Established, "Session2 is still established"); + + // Send fin-ack packets back + let (fin_ack_pkt1, action) = session2.process_recv(&fin_pkt.unwrap().swapped()); + check!(fin_ack_pkt1.is_some(), "Created a fin-ack packet"); + check!(action == SessionAction::EndOfStream, "Should be end of stream now"); + check!(session1.session_state == TCPSessionState::Closing, "Session1 is closing"); + check!(session2.session_state == TCPSessionState::Closing, "Session2 is now closing"); + + // Can close session2 now + let fin_pkt = session2.close(); + check!(fin_pkt.is_ok(), "Created a fin packet"); + check!(session1.session_state == TCPSessionState::Closing, "Session1 is closing"); + check!(session2.session_state == TCPSessionState::Closing, "Session2 is closing"); + + // Send fin-ack packets back + let (fin_ack_pkt2, action) = session1.process_recv(&fin_pkt.unwrap().swapped()); + check!(fin_ack_pkt2.is_some(), "Created a fin-ack packet"); + check!(action == SessionAction::EndOfStream, "Should be end of stream"); + check!(session1.session_state == TCPSessionState::Closing, "Session1 is closing"); + check!(session2.session_state == TCPSessionState::Closing, "Session2 is closing"); + + // Receive fin-ack packet + let (no_res, action) = session1.process_recv(&fin_ack_pkt2.unwrap().swapped()); + check!(no_res.is_none(), "No more packets"); + check!(action == SessionAction::EndOfStream, "Should be end of stream now"); + check!(session1.session_state == TCPSessionState::Closed, "Session1 is now closed"); + check!(session2.session_state == TCPSessionState::Closing, "Session2 is closing"); + + // Recv fin-ack + let (no_res, action) = session2.process_recv(&fin_ack_pkt1.unwrap().swapped()); + check!(no_res.is_none(), "No more packets"); + check!(action == SessionAction::EndOfStream, "Should be end of stream now"); + check!(session1.session_state == TCPSessionState::Closed, "Session1 is closed"); + check!(session2.session_state == TCPSessionState::Closed, "Session2 is now closed"); + + test_ok!(); +} \ No newline at end of file diff --git a/kernel/src/network/test.rs b/kernel/src/network/test.rs index f743711..b2a485a 100644 --- a/kernel/src/network/test.rs +++ b/kernel/src/network/test.rs @@ -1,8 +1,11 @@ use alloc::string::String; -use crate::{serial_println, QemuExitCode, exit_qemu}; +use crate::{exit_qemu, serial_println, QemuExitCode}; -use super::{arp_table, arp, bytefield, command_register, devices, dhcp, ethernet, ip, layer, netsync, network_query, processing, raw_array, raw_socket, udp}; +use super::{ + arp, arp_table, bytefield, command_register, devices, dhcp, ethernet, ip, layer, netsync, network_query, processing, raw_array, + raw_socket, tcp, udp, rtl8139, tcp_session, socket, +}; // Normal sync tests (run before starting async processing) pub fn test_sync() -> Result<(), String> { @@ -18,6 +21,9 @@ pub fn test_sync() -> Result<(), String> { netsync::test()?; raw_array::test()?; udp::test()?; + tcp::test()?; + rtl8139::test()?; + tcp_session::test()?; Ok(()) } @@ -35,12 +41,19 @@ pub async fn test_async() { exit_qemu(QemuExitCode::Failed); } - // Test processing (stage 3) - if let Err(err) = raw_socket::test().await { + // Test raw socket (stage 3) + if let Err(err) = raw_socket::test().await { serial_println!("[ERR] {}", err); exit_qemu(QemuExitCode::Failed); } - + + // Test socket (stage 4) + if let Err(err) = socket::test().await { + serial_println!("[ERR] {}", err); + exit_qemu(QemuExitCode::Failed); + } + + serial_println!("Passed all tests"); } #[macro_export] @@ -50,8 +63,8 @@ macro_rules! check { serial_println!("\nfile: {}:{} \x1b[91m[failed]\x1b[0m\n", file!(), line!()); return Err(String::from($err)); } // else { - // serial_println!("\t{}, file: {}:{} \x1b[92m[ok]\x1b[0m", $err, file!(), line!()); - // } + // serial_println!("\t{}, file: {}:{} \x1b[92m[ok]\x1b[0m", $err, file!(), line!()); + //} }; } @@ -66,8 +79,10 @@ macro_rules! test_ok { #[macro_export] macro_rules! mark_as_test { ($description: expr) => { - let to_print = format_args!("Running {} test {}:{}", $description, file!(), line!()).as_str(); - let as_string = to_print.unwrap_or(""); - serial_print!("{: <80}", as_string); + let to_print = format_args!("Testing {}", $description).as_str(); + let to_print2 = format_args!("{}:{}", file!(), line!()).as_str(); + let test_name = to_print.unwrap_or(""); + let test_file = to_print2.unwrap_or(""); + serial_print!("{: <35}{: <45}", test_name, test_file); }; -} \ No newline at end of file +} diff --git a/kernel/src/network/udp.rs b/kernel/src/network/udp.rs index 8577bf3..3cd8467 100644 --- a/kernel/src/network/udp.rs +++ b/kernel/src/network/udp.rs @@ -81,6 +81,7 @@ impl Layer for UDPPacket { } // Save ip packet and read 8 bytes let mut i = 0; + let expected_length = ip_layer.total_length.val(); packet.ip = ip_layer; packet.src_port = Bytefield16::read_inc(&bytevec[i..], &mut i); packet.dest_port = Bytefield16::read_inc(&bytevec[i..], &mut i); @@ -88,6 +89,9 @@ impl Layer for UDPPacket { packet.checksum = Bytefield16::read_inc(&bytevec[i..], &mut i); // assert 8 bytes assert!(i == 8); + if packet.length.val() + IPPacket::packet_size() != expected_length { + return (packet, 0, LayerType::Err); + } // Match the destination port to see if its DHCP let layer_type = match packet.dest_port.val() { // If port 68, send to DHCP layer @@ -205,11 +209,12 @@ pub fn test() -> Result<(), String> { ); // Create another UDP packet + let data = vec![1, 2, 3, 4, 5, 6, 7]; + let payload_size = data.len() as u16; let eth = EthernetPacket::gen(1, 2, EthType::IPv4); - let ip = IPPacket::gen(eth.clone(), 0, Protocol::Udp, 3, 4); + let ip = IPPacket::gen(eth.clone(), payload_size + UDPPacket::packet_size(), Protocol::Udp, 3, 4); let mut udp = UDPPacket::gen(ip, 5, 6, 7); - udp.data = vec![1, 2, 3, 4, 5, 6, 7]; - let payload_size = udp.data.len() as u16; + udp.data = data; check!(udp.src_port.swapped().val() == 5, "Check src port"); check!(udp.dest_port.swapped().val() == 6, "Check dest port"); check!( @@ -236,7 +241,7 @@ pub fn test() -> Result<(), String> { let (_, should_be_err) = full_parse(&serialized); check!(should_be_err.get_type() == LayerType::Err, "UDP packet has less data than promised"); - // Create another UDP packet + // Create another UDP packet with no body, make it smaller than usual let eth = EthernetPacket::gen(1, 2, EthType::IPv4); let ip = IPPacket::gen(eth.clone(), 0, Protocol::Udp, 3, 4); let udp = UDPPacket::gen(ip, 5, 6, 0); diff --git a/kernel/src/task/tcp_echo.rs b/kernel/src/task/tcp_echo.rs index 18141f6..de3f7e4 100644 --- a/kernel/src/task/tcp_echo.rs +++ b/kernel/src/task/tcp_echo.rs @@ -10,7 +10,7 @@ use alloc::string::String; /// A TCP application that listens on port 6664 and echos any data sent to it /// Used for testing or learning how to use the socket API pub async fn tcp_echo_server() { - let socket_or_err = Socket::open(SocketType::TCP, 6664, 0).await; + let socket_or_err = Socket::open(SocketType::TCP, 6664, 2).await; if let Err(err) = socket_or_err { println!("[ERR] {:?}", err); return; diff --git a/kernel/src/task/timeout.rs b/kernel/src/task/timeout.rs index a3bfd0d..f7c87da 100644 --- a/kernel/src/task/timeout.rs +++ b/kernel/src/task/timeout.rs @@ -1,5 +1,5 @@ -use alloc::collections::BinaryHeap; -use core::{cell::RefCell, sync::atomic::AtomicU64, task::Waker}; +use core::{sync::atomic::AtomicU64, task::Waker}; +use alloc::collections::BTreeMap; use lazy_static::lazy_static; use x86_64::instructions::interrupts; @@ -24,24 +24,18 @@ impl TimeoutID { /// An entry in the timeout structures struct TimeoutEntry { - /// The id of the timeout - id: TimeoutID, /// What epoch to wake at epochs: u64, /// The waker to use to wake a task waker: Waker, - /// If the timeout was cancelled and we are just ignoring it - cancelled: bool, } impl TimeoutEntry { /// Create a new timeout entry - pub fn new(id: TimeoutID, epochs: u64, waker: Waker) -> Self { + pub fn new(epochs: u64, waker: Waker) -> Self { TimeoutEntry { - id, epochs, waker, - cancelled: false, } } } @@ -71,8 +65,8 @@ impl Ord for TimeoutEntry { } lazy_static! { - /// A binary heap of timeout entrys - static ref TIMEOUT_QUEUE: spin::Mutex>> = spin::Mutex::new(BinaryHeap::new()); + /// A binary heap of timeout entries + static ref TIMEOUT_MAP: spin::Mutex> = spin::Mutex::new(BTreeMap::new()); } /// Read the interrupt counter @@ -86,12 +80,12 @@ pub fn read_interrupt_counter() -> u64 { pub fn register_timeout(after_epochs: u16, waker: Waker) -> TimeoutID { let timeout_id = TimeoutID::new(); interrupts::without_interrupts(|| { - let mut timeout_queue = TIMEOUT_QUEUE.lock(); - timeout_queue.push(RefCell::new(TimeoutEntry::new( - timeout_id, + // insert our timeout entry + let mut timeout_map = TIMEOUT_MAP.lock(); + timeout_map.insert(timeout_id, TimeoutEntry::new( unsafe { INTERRUPT_COUNTER } + after_epochs as u64, waker, - ))); + )); }); timeout_id } @@ -100,15 +94,8 @@ pub fn register_timeout(after_epochs: u16, waker: Waker) -> TimeoutID { pub fn cancel_timeout(id: TimeoutID) { // Without interrupts (no timer interrupts) interrupts::without_interrupts(|| { - // Lock the queue and find our timeout entry - let timeout_queue = TIMEOUT_QUEUE.lock(); - for entry in timeout_queue.iter() { - if entry.borrow().id.0 == id.0 { - // Once found, we cancel the timeout entry - entry.borrow_mut().cancelled = true; - break; - } - } + // Lock the queue and remove our timeout entry + TIMEOUT_MAP.lock().remove(&id); }); } @@ -116,23 +103,31 @@ pub fn cancel_timeout(id: TimeoutID) { /// *Only run from the interrupt context* pub fn poll_timeouts() { // We lock the timeout queue - let mut timeout_queue = TIMEOUT_QUEUE.lock(); + let mut timeout_map = TIMEOUT_MAP.lock(); // Increment the counter unsafe { INTERRUPT_COUNTER += 1 }; - // And continously read timeout entrys - while let Some(timeout_entry) = timeout_queue.peek() { + let mut to_delete: [TimeoutID; 40] = [TimeoutID(0); 40]; + let mut i = 0; + // Continuously read timeout entries + for (timeout_id, timeout_entry) in timeout_map.iter() { // If the timeout entry is expired - if timeout_entry.borrow().epochs <= unsafe { INTERRUPT_COUNTER } { - // And its not cancelled - if !timeout_entry.borrow().cancelled { - // Then we wake the timeout entry's waker - timeout_entry.borrow().waker.wake_by_ref(); + if timeout_entry.epochs <= unsafe { INTERRUPT_COUNTER } { + // Add it to the list + to_delete[i] = *timeout_id; + i += 1; + if i == 40 { + // break before we overflow the array + break; } - // We actually remove the timeout entry if its expired - timeout_queue.pop(); } else { - // If we can't remove any timeouts - we break + // Otherwise break break; } } + for id in to_delete.iter().take(i) { + // Remove the entry from the map + let entry = timeout_map.remove(id).unwrap(); + // And wake the timeout entry's waker + entry.waker.wake(); + } } From 1301f25be8f8e8da2a32b909caa1561991e94154 Mon Sep 17 00:00:00 2001 From: EthanLavi <89608076+EthanLavi@users.noreply.github.com> Date: Sat, 9 Dec 2023 20:05:54 -0500 Subject: [PATCH 30/36] Final bug fixes to socket --- kernel/src/network/socket.rs | 114 ++++++++++++++++++++++++++---- kernel/src/network/tcp_session.rs | 11 +-- 2 files changed, 106 insertions(+), 19 deletions(-) diff --git a/kernel/src/network/socket.rs b/kernel/src/network/socket.rs index 2d95f24..721edfd 100644 --- a/kernel/src/network/socket.rs +++ b/kernel/src/network/socket.rs @@ -43,6 +43,7 @@ enum SocketState { /// The socket object for communication with the network /// Will utilize the network stack +#[derive(Debug)] pub struct Socket { socket_type: SocketType, socket_state: SocketState, @@ -56,6 +57,7 @@ pub struct Socket { src_mac: u64, wait_timeout: u16, session_key: u64, + pub urgent: bool, } impl Socket { @@ -109,6 +111,7 @@ impl Socket { dest_mac: 0, src_port, src_address, + urgent: false, src_mac, wait_timeout, session_key: 0, @@ -186,6 +189,7 @@ impl Socket { src_address: self.src_address, src_mac: self.src_mac, wait_timeout: self.wait_timeout, + urgent: self.urgent, session_key, })); } @@ -257,6 +261,8 @@ impl Socket { let mut rtl_dev_config = NET_INFO.config.lock(); // Guard if session exists if rtl_dev_config.tcp_sessions.contains_key(&session_key) { + session_socket.close(); + port_socket.unwrap().close(); return Err(NetworkErrors::PortInUse); } // Compact a template packet @@ -290,6 +296,8 @@ impl Socket { } // We ran out of retries, exiting with an error if !found { + port_socket.unwrap().close(); + session_socket.close(); return Err(NetworkErrors::NonexistentHost); } @@ -307,6 +315,7 @@ impl Socket { src_address: src_ip, src_mac, wait_timeout, + urgent: false, session_key: TCPSession::gen_session_key(dest_ip, dest_port, src_port), }) } @@ -385,9 +394,9 @@ impl Socket { /// Internal function for reading as a UDP socket #[allow(clippy::question_mark)] async fn read_udp(&mut self, size: usize) -> Result, NetworkErrors> { + // Create a result vector + let mut res_vec = vec![]; loop { - // Create a result vector - let mut res_vec = vec![]; // Loop until we get a packet if let Some(pkt_or_err) = self.raw_socket.next().await { // If the packet is an error, read_udp returns an error @@ -415,9 +424,9 @@ impl Socket { /// Internal function for reading as a TCP socket #[allow(clippy::question_mark)] async fn read_tcp(&mut self, size: usize) -> Result, NetworkErrors> { + // Create a result vector + let mut res_vec = vec![]; loop { - // Create a result vector - let mut res_vec = vec![]; // Loop until we get a packet if let Some(pkt_or_err) = self.raw_socket.next().await { // If the packet is an error, read_tcp returns an error @@ -515,7 +524,7 @@ impl Socket { let tcp_session = tcp_session_guard.tcp_sessions.get_mut(&self.session_key).unwrap(); let data_len = min(data.len(), min(tcp_session.window_size as usize, 1380)); // Message pkt is our present to send to our server - let message_pkt = tcp_session.process_send(&data[..data_len]); + let message_pkt = tcp_session.process_send(&data[..data_len], false); // Trim the array to remove the sent data let mut index = 0; data.retain(|_| { @@ -527,8 +536,7 @@ impl Socket { return Err(err); } // Unwrap the message - let tmp = message_pkt.unwrap(); - let message = tmp.serialize(); + let message = message_pkt.unwrap().serialize(); drop(tcp_session_guard); // Wait for the ack -- 5 retries @@ -564,6 +572,7 @@ impl Socket { } } +#[allow(clippy::needless_range_loop)] pub async fn test() -> Result<(), String> { mark_as_test!("Socket (stage 4)"); // UDP socket first @@ -582,7 +591,7 @@ pub async fn test() -> Result<(), String> { check!(data.is_ok(), "Data was received ok"); check!(data.unwrap() == vec![0, 1], "Data is correct"); - for _ in 0..1000 { + for _ in 0..100 { // writing 200 bytes let not_err = send.reliable_write(&mut vec![0; 200]).await; check!(not_err.is_none(), "Writing is not an error"); @@ -608,8 +617,8 @@ pub async fn test() -> Result<(), String> { let mut send = send_wrapped.unwrap(); let mut recv = recv_wrapped.unwrap().unwrap(); - // Send data 10 times - for _ in 0..10 { + // Send data 5 times + for _ in 0..5 { // Writing 200 bytes let not_err = send.reliable_write(&mut vec![0; 200]).await; if not_err.is_some() { @@ -623,12 +632,87 @@ pub async fn test() -> Result<(), String> { send.close().await; recv.close().await; } + server_socket.close().await; + + // Errors with sending in bad states (without listening) + let mut s1 = Socket::open(SocketType::TCP, 60001, 2).await.unwrap(); + let s2 = Socket::connect(SocketType::TCP, my_ip, 59999, 60000, 2).await; + check!(s2.is_err() && s2.unwrap_err() == NetworkErrors::NonexistentHost, "No socket"); + let write_err = s1.reliable_write(&mut vec![0]).await; + check!(write_err.is_some() && write_err.unwrap() == NetworkErrors::BadSocketState, "Write error"); + let read_err = s1.read(1).await; + check!(read_err.is_err() && read_err.unwrap_err() == NetworkErrors::BadSocketState, "Read error"); + // close + s1.close().await; + + // TCP server with multiple clients + let mut server_socket = Socket::open(SocketType::TCP, 60001, 5).await.unwrap(); + let mut clients: Vec = Vec::new(); + for k in 0..10 { + clients.push(Socket::connect(SocketType::TCP, my_ip, 60001, 60002 + k, 5).await.unwrap()); + } + let mut receivers: Vec = Vec::new(); + for _ in 0..10 { + receivers.push(server_socket.listen().await.unwrap().unwrap()); + } + + // Send data once for each socket + for i in 0..10 { + let no_err = clients[i].reliable_write(&mut vec![i as u8]).await; + check!(no_err.is_none(), "Data was well sent"); + } + for i in 0..10 { + let data = receivers[i].read(1).await.unwrap(); + check!(data[0] == i as u8, "Data was well received"); + } + // close each socket + for _ in 0..10 { + clients.pop().unwrap().close().await; + receivers.pop().unwrap().close().await; + } + server_socket.close().await; + + // sending multiple packets and read in chunks + let mut server_socket = Socket::open(SocketType::TCP, 60001, 5).await.unwrap(); + let mut client = Socket::connect(SocketType::TCP, my_ip, 60001, 60000, 5).await.unwrap(); + let mut receiver = server_socket.listen().await.unwrap().unwrap(); + // Send data/read data in chunks, check ordering + for i in 100..110 { + // Sending data + let no_err = client.reliable_write(&mut vec![1; 10]).await; + check!(no_err.is_none(), "Data was well sent"); + let no_err = client.reliable_write(&mut vec![i as u8; i]).await; + check!(no_err.is_none(), "Data was well sent"); + let no_err = client.reliable_write(&mut vec![2; 10]).await; + check!(no_err.is_none(), "Data was well sent"); + let no_err = client.reliable_write(&mut vec![3; 1400]).await; // ~2KB + check!(no_err.is_none(), "Data was well sent"); + + let mut data: Vec = Vec::new(); + data.append(&mut receiver.read(10 + i).await.unwrap()); + check!(data.len() == 10 + i, "Data reading is as expected"); + data.append(&mut receiver.read(1).await.unwrap()); // should read 10 + check!(data.len() == 20 + i, "Data reading, can be more than expected"); + data.append(&mut receiver.read(1400).await.unwrap()); + check!(data.len() == 1420 + i, "Data reading, big packet"); + for (j, val) in data.iter().enumerate() { + if j < 10 { + check!(*val == 1, "Next 10 values are 1"); + } else if j < 10 + i { + check!(*val == i as u8, "Next i values are i"); + } else if j < 20 + i { + check!(*val == 2, "Next 10 values are 2"); + } else { + check!(*val == 3, "Anything after that is 3"); + } + } + } + // close each socket + client.close().await; + receiver.close().await; + server_socket.close().await; - // todo: Errors with sending in bad states (without listening) - // todo: TCP server with multiple clients - // todo: closing with no good state - // todo: sending multiple packets and read in chunks // todo: what happens if we both write and then read... !! - // todo: data comes completely + test_ok!(); } \ No newline at end of file diff --git a/kernel/src/network/tcp_session.rs b/kernel/src/network/tcp_session.rs index 96296c9..a1c96bd 100644 --- a/kernel/src/network/tcp_session.rs +++ b/kernel/src/network/tcp_session.rs @@ -136,7 +136,7 @@ impl TCPSession { } /// A function for generating a packet to sent with data - pub fn process_send(&mut self, data: &[u8]) -> Result { + pub fn process_send(&mut self, data: &[u8], urgent: bool) -> Result { // Check session state if self.session_state != TCPSessionState::Established { return Err(NetworkErrors::BadSocketState); @@ -144,7 +144,10 @@ impl TCPSession { // Clone the template let mut tcp_pkt = self.session_template.clone(); // Set flags and data - tcp_pkt.turn_on_flags(TCP_ACK | TCP_PSH); + tcp_pkt.turn_on_flags(TCP_ACK); + if urgent { + tcp_pkt.turn_on_flags(TCP_PSH); + } tcp_pkt.data = data.to_vec(); // add the data size @@ -376,7 +379,7 @@ pub fn test() -> Result<(), String> { // Send data on session1 let data = vec![0; i + 1]; - let to_send = session1.process_send(&data); + let to_send = session1.process_send(&data, false); check!(to_send.is_ok(), "Process send is ok"); check!(!session1.everything_acked(), "Everything not acked yet"); @@ -478,7 +481,7 @@ pub fn test() -> Result<(), String> { // Send data on session1 let data = vec![0; i + 1]; - let to_send_wrapped = session1.process_send(&data); + let to_send_wrapped = session1.process_send(&data, false); check!(to_send_wrapped.is_ok(), "Process send is ok"); let to_send = to_send_wrapped.unwrap(); check!(!session1.everything_acked(), "Everything not acked yet"); From 8092bd1fe7ecb6d169c2d83d5add8819e04ab6b7 Mon Sep 17 00:00:00 2001 From: EthanLavi <89608076+EthanLavi@users.noreply.github.com> Date: Sun, 10 Dec 2023 16:31:13 -0500 Subject: [PATCH 31/36] bug fixes and merges --- .cargo/config.toml | 2 +- kernel/Cargo.toml | 13 +- kernel/scripts/broadcast_module.py | 27 +++ kernel/scripts/network_stack.py | 57 +++++ kernel/scripts/test.wasm | 329 +++++++++++++++++++++++++++++ kernel/src/allocator.rs | 2 +- kernel/src/interrupts.rs | 46 +++- kernel/src/main.rs | 3 +- kernel/src/network/processing.rs | 4 +- kernel/src/network/raw_socket.rs | 9 + kernel/src/network/rtl8139.rs | 13 +- kernel/src/network/socket.rs | 75 +++++-- kernel/src/network/tcp.rs | 2 +- kernel/src/network/tcp_session.rs | 4 +- kernel/src/network/test.rs | 4 +- kernel/src/task/mod.rs | 2 + kernel/src/task/tcp_echo.rs | 65 +++--- kernel/src/task/test_reader.rs | 61 ++++++ kernel/src/task/udp_echo.rs | 50 ++--- src/main.rs | 4 +- 20 files changed, 654 insertions(+), 118 deletions(-) create mode 100644 kernel/scripts/broadcast_module.py create mode 100644 kernel/scripts/network_stack.py create mode 100644 kernel/scripts/test.wasm create mode 100644 kernel/src/task/test_reader.rs diff --git a/.cargo/config.toml b/.cargo/config.toml index b631240..9b1b4b7 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -1,5 +1,5 @@ # [build] -# target = "x86_64-unknown-none" +# target = "wasm32-unknown-unknown" # [unstable] # build-std = ["core", "compiler_builtins", "alloc"] diff --git a/kernel/Cargo.toml b/kernel/Cargo.toml index ffb5708..ed24cf7 100644 --- a/kernel/Cargo.toml +++ b/kernel/Cargo.toml @@ -18,6 +18,7 @@ linked_list_allocator = "0.10.5" pci = "0.0.1" noto-sans-mono-bitmap = "0.2.0" hashbrown = "0.14.2" +md-5 = { version = "0.10.6", default-features = false } [dependencies.wasmi] version = "0.31.0" @@ -40,15 +41,3 @@ default-features = false version = "0.3.4" default-features = false features = ["alloc"] - -# [[test]] -# name = "should_panic" -# harness = false - -# [[test]] -# name = "stack_overflow" -# harness = false - -# [workspace] -# members = ["wasm-demos"] -# default-members = [".", "wasm-demos"] \ No newline at end of file diff --git a/kernel/scripts/broadcast_module.py b/kernel/scripts/broadcast_module.py new file mode 100644 index 0000000..c5fd9bd --- /dev/null +++ b/kernel/scripts/broadcast_module.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 -u +# Broadcast a WASM module over TCP, then connect the socket to stdio +import sys +import os +import struct +import socket +import time + +if __name__ == "__main__": + if len(sys.argv) != 2: + print(f"Usage: {sys.argv[0]} ", file=sys.stderr) + sys.exit(1) + + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.connect(("localhost", 7777)) + + wasm_path = sys.argv[1] + file_size = os.path.getsize(wasm_path) + s.sendall(struct.pack(" u16 { + let mut pic1_cmd = Port::::new(0x20); + let mut pic2_cmd = Port::::new(0xA0); + pic1_cmd.write(ocw3); + pic2_cmd.write(ocw3); + let reg1 = pic1_cmd.read(); + let reg2 = pic2_cmd.read(); + ((reg2 as u16) << 8) | (reg1 as u16) +} + +fn pic_get_irr() -> u16 { + unsafe { pic_get_irq_reg(0x0a) } +} +fn pic_get_isr() -> u16 { + unsafe { pic_get_irq_reg(0x0b) } +} + +extern "x86-interrupt" fn keyboard_interrupt_handler(_stack_frame: InterruptStackFrame) { let mut port = Port::new(0x60); let scancode: u8 = unsafe { port.read() }; crate::task::keyboard::add_scancode(scancode); @@ -137,8 +156,23 @@ extern "x86-interrupt" fn keyboard_interrupt_handler(_stack_frame: InterruptStac } } -#[test_case] -fn test_breakpoint_exception() { - // invoke a breakpoint exception - x86_64::instructions::interrupts::int3(); +extern "x86-interrupt" fn handle_irq7(_stack_frame: InterruptStackFrame) { + let _irr = pic_get_irr(); + let isr = pic_get_isr(); + + if isr & (1 << 7) != 0 { + panic!("Non-spurious IRQ7?!"); + } } + +extern "x86-interrupt" fn handle_irq15(_stack_frame: InterruptStackFrame) { + let _irr = pic_get_irr(); + let isr = pic_get_isr(); + + if isr & (1 << 15) != 0 { + panic!("Non-spurious IRQ15?!"); + } + unsafe { + PICS.lock().notify_end_of_interrupt(PIC_1_OFFSET); + } +} \ No newline at end of file diff --git a/kernel/src/main.rs b/kernel/src/main.rs index 9a5433d..523ae9f 100644 --- a/kernel/src/main.rs +++ b/kernel/src/main.rs @@ -19,7 +19,7 @@ use kernel::{ println, task::keyboard, task::{executor::Executor, Task}, - task::{tcp_echo, udp_echo}, exit_qemu, + task::{tcp_echo, udp_echo, test_reader}, exit_qemu, }; extern crate alloc; @@ -109,6 +109,7 @@ fn kernel_main(boot_info: &'static mut BootInfo) -> ! { executor.spawn(Task::new(keyboard::print_keypresses())); executor.spawn(Task::new(udp_echo::udp_echo_server())); executor.spawn(Task::new(tcp_echo::tcp_echo_server())); + executor.spawn(Task::new(test_reader::test_reader_server())); executor.run(); } diff --git a/kernel/src/network/processing.rs b/kernel/src/network/processing.rs index c60d250..e47745a 100644 --- a/kernel/src/network/processing.rs +++ b/kernel/src/network/processing.rs @@ -38,7 +38,7 @@ pub struct PacketBuf { } /// An array queue for data to parse static PENDING_DATA: OnceCell> = OnceCell::uninit(); -static MAX_WINDOW_SIZE: u16 = 20; +static MAX_WINDOW_SIZE: u16 = 40; static mut WINDOW_SIZE: u16 = MAX_WINDOW_SIZE; pub fn get_window_size() -> u16 { unsafe { WINDOW_SIZE } @@ -54,7 +54,7 @@ struct PendingProcessingStream { impl PendingProcessingStream { /// Create a new pending process stream fn new() -> Self { - // Initialize the pending data array queue with max size 20 + // Initialize the pending data array queue with max size 40 PENDING_DATA .try_init_once(|| ArrayQueue::new(MAX_WINDOW_SIZE as usize)) .expect("PendingProcessingStream::new should only be called once"); diff --git a/kernel/src/network/raw_socket.rs b/kernel/src/network/raw_socket.rs index 223db3b..0492ac6 100644 --- a/kernel/src/network/raw_socket.rs +++ b/kernel/src/network/raw_socket.rs @@ -33,6 +33,8 @@ pub struct RawSocket { timeout_active: bool, /// The id of the timeout we registered (for cancellation) timeout_id: TimeoutID, + /// Logically closed (will always return NetworkErrors::ClosedSocket) + is_end_of_stream: bool, } impl RawSocket { @@ -58,6 +60,7 @@ impl RawSocket { timeout_in_epochs, timeout_active: false, timeout_id: TimeoutID::new(), + is_end_of_stream: false, }) } @@ -95,6 +98,9 @@ impl Stream for RawSocket { /// Poll for a new packet on the raw socket (it's main purpose) fn poll_next(mut self: core::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.is_end_of_stream { + return Poll::Ready(Some(Err(NetworkErrors::ClosedSocket))); + } // Try to get a packet let pkt = self.try_get_packet_inner(); @@ -106,6 +112,9 @@ impl Stream for RawSocket { // If we found a packet - cancel any pending timeouts cancel_timeout(self.timeout_id); self.timeout_active = false; + if let Some(Err(NetworkErrors::ClosedSocket)) = pkt { + self.is_end_of_stream = true; + } // And return the packet Poll::Ready(pkt) } else if self.timeout_active { diff --git a/kernel/src/network/rtl8139.rs b/kernel/src/network/rtl8139.rs index 8c9fe27..fbfe84c 100644 --- a/kernel/src/network/rtl8139.rs +++ b/kernel/src/network/rtl8139.rs @@ -168,20 +168,13 @@ fn recv_packet(rtl_dev_info: &RTL8139) { }); } // after receiving the packet, update CAPR and RECV_POS - // increment recv_pos - unsafe { - RECV_POS = (RECV_POS + 2) % RX_BUFFER_SIZE; - } - unsafe { - RECV_POS = (RECV_POS + length) % RX_BUFFER_SIZE; - } - // we and with RX_READ_PTR_MASK to ensure double word alignment + // increment recv_pos by and-ing with RX_READ_PTR_MASK to ensure double word alignment // 4 for header length, 3 is also apparently for dword alignment... unsafe { - RECV_POS = ((RECV_POS + 4) & RX_READ_PTR_MASK) % RX_BUFFER_SIZE; + RECV_POS = (( RECV_POS + 4 + length + 3) & RX_READ_PTR_MASK ) % RX_BUFFER_SIZE; } let mut capr = Port::::new((rtl_dev_info.device_info.io_base.unwrap() + CAPR) as u16); - unsafe { capr.write(RECV_POS - 0x10) }; + unsafe { capr.write(RECV_POS.wrapping_sub(0x10)) }; } else { unsafe { RECV_POS = (RECV_POS + 2) % RX_BUFFER_SIZE; diff --git a/kernel/src/network/socket.rs b/kernel/src/network/socket.rs index 721edfd..0e93650 100644 --- a/kernel/src/network/socket.rs +++ b/kernel/src/network/socket.rs @@ -6,7 +6,7 @@ use alloc::{boxed::Box, vec}; use futures_util::StreamExt; use crate::{test_ok, mark_as_test, check, serial_print, serial_println}; -use crate::{network::layer::LayerType, println}; +use crate::network::layer::LayerType; use super::bytefield::Bytefield32; use super::constants::TCP_SYN; @@ -26,6 +26,15 @@ use super::{ udp::UDPPacket, }; +#[allow(clippy::drain_collect)] +fn split_vec(vec: &mut Vec, size: usize) -> Vec { + if vec.len() >= size { + vec.drain(0..size).collect() + } else { + vec.drain(0..vec.len()).collect() + } +} + /// An enum to represent socket type (UDP or TCP) #[derive(Debug, PartialEq, Eq, Clone, Copy)] pub enum SocketType { @@ -57,6 +66,7 @@ pub struct Socket { src_mac: u64, wait_timeout: u16, session_key: u64, + read_buffer: Vec, pub urgent: bool, } @@ -115,6 +125,7 @@ impl Socket { src_mac, wait_timeout, session_key: 0, + read_buffer: Vec::new(), }), // otherwise return the error from the raw socket construction Err(err) => Err(err), @@ -174,7 +185,6 @@ impl Socket { let session_key = TCPSession::gen_session_key(dest_ip, dest_port, self.src_port); // And a raw socket let raw_socket = RawSocket::new(session_key, max(self.wait_timeout * 18, 1)).unwrap(); - println!("[INFO] Spawned new TCP session"); // And return a new socket object to be the ready socket // the current socket never transitions out of listening return Ok(Some(Socket { @@ -191,6 +201,7 @@ impl Socket { wait_timeout: self.wait_timeout, urgent: self.urgent, session_key, + read_buffer: Vec::new(), })); } } @@ -300,7 +311,6 @@ impl Socket { session_socket.close(); return Err(NetworkErrors::NonexistentHost); } - } // If the raw socket was created successfully, we return a new socket object Ok(Socket { @@ -317,6 +327,7 @@ impl Socket { wait_timeout, urgent: false, session_key: TCPSession::gen_session_key(dest_ip, dest_port, src_port), + read_buffer: Vec::new(), }) } @@ -355,7 +366,7 @@ impl Socket { loop { // Keep reading the stream until we get an error - if let Err(next) = self.read_tcp(0).await { + if let Err(next) = self.read_tcp(1).await { if next == NetworkErrors::ClosedSocket { // If we have a closed stream, we break out completely break 'outer; @@ -379,6 +390,8 @@ impl Socket { } /// Read data from the socket + /// 0: Will return a vector + /// 1: Will return any errors associated with the vector (i.e. a closed socket) pub async fn read(&mut self, size: usize) -> Result, NetworkErrors> { // Check socket state if self.socket_state != SocketState::Ready { @@ -394,8 +407,9 @@ impl Socket { /// Internal function for reading as a UDP socket #[allow(clippy::question_mark)] async fn read_udp(&mut self, size: usize) -> Result, NetworkErrors> { - // Create a result vector - let mut res_vec = vec![]; + if self.read_buffer.len() >= size { + return Ok(split_vec(&mut self.read_buffer, size)); + } loop { // Loop until we get a packet if let Some(pkt_or_err) = self.raw_socket.next().await { @@ -408,11 +422,11 @@ impl Socket { // If we have a matching UDP packet, unwrap it let mut udp_pkt = pkt.unwrap_udp(); // Save the packet into our buffer - res_vec.append(&mut udp_pkt.data); + self.read_buffer.append(&mut udp_pkt.data); } // Our size has exceeded the request, so we can return the resulting data - if size <= res_vec.len() { - return Ok(res_vec); + if size <= self.read_buffer.len() { + return Ok(split_vec(&mut self.read_buffer, size)); } } else { // Our socket timed-out reading, so we return the timeout error @@ -424,28 +438,33 @@ impl Socket { /// Internal function for reading as a TCP socket #[allow(clippy::question_mark)] async fn read_tcp(&mut self, size: usize) -> Result, NetworkErrors> { - // Create a result vector - let mut res_vec = vec![]; + // Return if there is enough data in the read buffer + if self.read_buffer.len() >= size { + return Ok(split_vec(&mut self.read_buffer, size)); + } loop { // Loop until we get a packet if let Some(pkt_or_err) = self.raw_socket.next().await { // If the packet is an error, read_tcp returns an error if let Err(err) = pkt_or_err { + if err == NetworkErrors::ClosedSocket && !self.read_buffer.is_empty() { + return Ok(split_vec(&mut self.read_buffer, size)); + } return Err(err); } let pkt = pkt_or_err.unwrap(); if pkt.get_type() == LayerType::Tcp { // If our packet matches our intended type, save the packet into our buffer let mut tcp_pkt = pkt.unwrap_tcp(); - res_vec.append(&mut tcp_pkt.data); + self.read_buffer.append(&mut tcp_pkt.data); // If our packet has PSH, we ignore the size request and immediately push to the application if tcp_pkt.get_flags() & TCP_PSH != 0 { - return Ok(res_vec); + return Ok(split_vec(&mut self.read_buffer, size)); } } // Our size has exceeded the request, so we can return the resulting data - if size <= res_vec.len() { - return Ok(res_vec); + if size <= self.read_buffer.len() { + return Ok(split_vec(&mut self.read_buffer, size)); } } else { // Our socket timed-out reading, so we return the timeout error @@ -591,6 +610,12 @@ pub async fn test() -> Result<(), String> { check!(data.is_ok(), "Data was received ok"); check!(data.unwrap() == vec![0, 1], "Data is correct"); + send.reliable_write(&mut vec![2, 3]).await; + let data1 = recv.read(1).await.unwrap(); + let data2 = recv.read(1).await.unwrap(); + check!(data1[0] == 2, "Data was received in pieces [1]"); + check!(data2[0] == 3, "Data was received in pieces [2]"); + for _ in 0..100 { // writing 200 bytes let not_err = send.reliable_write(&mut vec![0; 200]).await; @@ -691,10 +716,10 @@ pub async fn test() -> Result<(), String> { let mut data: Vec = Vec::new(); data.append(&mut receiver.read(10 + i).await.unwrap()); check!(data.len() == 10 + i, "Data reading is as expected"); - data.append(&mut receiver.read(1).await.unwrap()); // should read 10 - check!(data.len() == 20 + i, "Data reading, can be more than expected"); - data.append(&mut receiver.read(1400).await.unwrap()); - check!(data.len() == 1420 + i, "Data reading, big packet"); + data.append(&mut receiver.read(1).await.unwrap()); + check!(data.len() == 11 + i, "Data reading, can be in pieces"); + data.append(&mut receiver.read(1409).await.unwrap()); + check!(data.len() == 1420 + i, "Data reading, over multiple packets"); for (j, val) in data.iter().enumerate() { if j < 10 { check!(*val == 1, "Next 10 values are 1"); @@ -707,12 +732,20 @@ pub async fn test() -> Result<(), String> { } } } - // close each socket + // reading after a socket has closed on one end + client.reliable_write(&mut vec![1; 10]).await; client.close().await; + let data = receiver.read(5).await; + check!(data.is_ok() && data.unwrap().len() == 5, "Data has no errors until we try to read a closed socket"); + let data = receiver.read(6).await; + check!(data.is_ok() && data.unwrap().len() == 5, "Data has no errors until we try to read a closed socket"); + let data = receiver.read(1).await; + check!(data.is_err() && data.unwrap_err() == NetworkErrors::ClosedSocket, "No more data in the stream"); receiver.close().await; server_socket.close().await; // todo: what happens if we both write and then read... !! - + // the polling of the writing could cause read to break! -- we are ignoring this for now... + test_ok!(); } \ No newline at end of file diff --git a/kernel/src/network/tcp.rs b/kernel/src/network/tcp.rs index b7dd833..58a337a 100644 --- a/kernel/src/network/tcp.rs +++ b/kernel/src/network/tcp.rs @@ -74,7 +74,7 @@ impl TCPPacket { flags: Bytefield16::new(5 << 12), // sliding window :(, pain to implement... // we can just allow unlimited data... Also we would like to keep track of how much we are allowed to send - sliding_window: Bytefield16::new(u16::MAX), + sliding_window: Bytefield16::new(3000), checksum: Bytefield16::new(0), urgent: Bytefield16::new(0), // we never provide options because we basic diff --git a/kernel/src/network/tcp_session.rs b/kernel/src/network/tcp_session.rs index a1c96bd..f7c599f 100644 --- a/kernel/src/network/tcp_session.rs +++ b/kernel/src/network/tcp_session.rs @@ -157,7 +157,8 @@ impl TCPSession { tcp_pkt.seq_num = Bytefield32::new(self.sent_data_amount); self.sent_data_amount += data.len() as u32; tcp_pkt.ack_num = Bytefield32::new(self.recv_data_amount); - + tcp_pkt.sliding_window = Bytefield16::new(self.window_size); + // Calculate checksums tcp_pkt.ip.calculate_checksum(); tcp_pkt.calculate_checksum(); @@ -180,6 +181,7 @@ impl TCPSession { self.window_size = request.sliding_window.val(); let mut response = self.session_template.clone(); // Setting length to be 24 [sizeof(TCP-header) + sizeof(options)] + 20 [sizeof(IP-header)] + response.sliding_window = Bytefield16::new(self.window_size); response.ip.total_length = Bytefield16::new(TCPPacket::packet_size() + TCPPacket::options_size() + IPPacket::packet_size()); // todo: check seq-num and ack-num before doing these state changes if has_fin_flag && self.session_state == TCPSessionState::Established { diff --git a/kernel/src/network/test.rs b/kernel/src/network/test.rs index b2a485a..8701101 100644 --- a/kernel/src/network/test.rs +++ b/kernel/src/network/test.rs @@ -62,9 +62,7 @@ macro_rules! check { if !$status { serial_println!("\nfile: {}:{} \x1b[91m[failed]\x1b[0m\n", file!(), line!()); return Err(String::from($err)); - } // else { - // serial_println!("\t{}, file: {}:{} \x1b[92m[ok]\x1b[0m", $err, file!(), line!()); - //} + } }; } diff --git a/kernel/src/task/mod.rs b/kernel/src/task/mod.rs index c599934..6fe3ec7 100644 --- a/kernel/src/task/mod.rs +++ b/kernel/src/task/mod.rs @@ -9,6 +9,8 @@ pub mod simple_executor; pub mod tcp_echo; pub mod timeout; pub mod udp_echo; +pub mod test_reader; +pub mod wasm_oneshot; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] pub struct TaskId(u64); diff --git a/kernel/src/task/tcp_echo.rs b/kernel/src/task/tcp_echo.rs index de3f7e4..3bd4370 100644 --- a/kernel/src/task/tcp_echo.rs +++ b/kernel/src/task/tcp_echo.rs @@ -28,38 +28,9 @@ pub async fn tcp_echo_server() { let mut socket = socket_or_err.unwrap().unwrap(); loop { // Continuously read from the socket - let data_or_err = socket.read(0).await; - if let Ok(mut data) = data_or_err { - if data.is_empty() { - // continue if we didn't read any data - continue; - } - // Get the user's message - let user_message = String::from_utf8(data.clone()); - if let Ok(message) = user_message { - // Print it out - print!("[USER] {}", message); - if message.to_lowercase().trim() == "quit" || message.to_lowercase().trim() == "exit" { - // If their message is quit or exit, we close the connection - let mut exit_message = "Closing socket...\n".as_bytes().to_vec(); - let is_err = socket.reliable_write(&mut exit_message).await; - if is_err.is_some() { - println!("[ERR] Writing error {:?}", is_err.unwrap()); - } - socket.close().await; - println!("Closed socket"); - break; - } - } - // Echo back the data from the socket - let res_or_err = socket.reliable_write(&mut data).await; - if let Some(err) = res_or_err { - // Writing - print error, close the socket, and break from the loop - println!("[ERR] (Writing): {:?}", err); - socket.close().await; - break; - } - } else if let Err(err) = data_or_err { + // let (mut data, error) = socket.read(4).await; + let read_result = socket.read(4).await; // ! + if let Err(err) = read_result { // ! if err == NetworkErrors::Timeout { // If we timed-out, we just loop again reading the socket continue; @@ -69,6 +40,36 @@ pub async fn tcp_echo_server() { socket.close().await; break; } + let mut data = read_result.unwrap(); // ! + if data.is_empty() { + // continue if we didn't read any data + continue; + } + // Get the user's message + let user_message = String::from_utf8(data.clone()); + if let Ok(message) = user_message { + // Print it out + print!("{}", message); + if message.to_lowercase().trim() == "quit" || message.to_lowercase().trim() == "exit" { + // If their message is quit or exit, we close the connection + let mut exit_message = "Closing socket...\n".as_bytes().to_vec(); + let is_err = socket.reliable_write(&mut exit_message).await; + if is_err.is_some() { + println!("[ERR] Writing error {:?}", is_err.unwrap()); + } + socket.close().await; + println!("Closed socket"); + break; + } + } + // Echo back the data from the socket + let res_or_err = socket.reliable_write(&mut data).await; + if let Some(err) = res_or_err { + // Writing - print error, close the socket, and break from the loop + println!("[ERR] (Writing): {:?}", err); + socket.close().await; + break; + } } } println!("[INFO] Finished up all allocated sockets for echoing"); diff --git a/kernel/src/task/test_reader.rs b/kernel/src/task/test_reader.rs new file mode 100644 index 0000000..2addffe --- /dev/null +++ b/kernel/src/task/test_reader.rs @@ -0,0 +1,61 @@ +use crate::{ + network::{ + errors::NetworkErrors, + socket::{Socket, SocketType}, + }, + println, serial_println, +}; +use alloc::vec::Vec; +use md5::{Md5, Digest}; + +/// WASM One Shot format: +/// - 32-bit little-endian header indicating module length +/// - WASM binary module +#[allow(clippy::question_mark)] +async fn get_wasm_file(mut stream: Socket) -> Result, NetworkErrors> { + let data = stream.read(4).await; // ! + if let Err(err) = data { // ! + return Err(err); + } + let len_bytes = data.unwrap(); // ! + assert!(len_bytes.len() == 4); + let mut bytes: [u8; 4] = [0; 4]; + for (i, byte) in len_bytes.iter().enumerate() { + bytes[i] = *byte; + } + let file_size = u32::from_le_bytes(bytes); + serial_println!("File is {} bytes", file_size); + let data = stream.read(file_size as usize).await; // ! + if let Err(err) = data { // ! + if err != NetworkErrors::ClosedSocket { + return Err(err); + } + return Err(NetworkErrors::FeatureNotAvailableYet) // ! -- this was the bug, if we read an error instead of data + } + Ok(data.unwrap()) // ! +} + +pub async fn test_reader_server() { + let socket_or_err = Socket::open(SocketType::TCP, 7774, 2).await; + if let Err(err) = socket_or_err { + println!("[ERR] {:?}", err); + return; + } + let mut socket_gen = socket_or_err.unwrap(); + // Allow 10 sockets to form (we don't have threads so only listening to one at a time) + for _ in 0..10 { + // Listen for a new session socket to form + let socket = socket_gen.listen().await.unwrap().unwrap(); + println!("[WASM] Stream started."); + let file_data = get_wasm_file(socket).await.unwrap(); + + let checksum = Md5::digest(&file_data); + crate::serial_println!("GOT MD5 CHECKSUM: {:02x}", checksum); + let expected_value = 0xd94ef4499ac16f70f559e4e0ef70173b_u128; + let real_value = u128::from_be_bytes(checksum.as_slice().try_into().unwrap()); + assert_eq!(expected_value, real_value); + + println!("[WASM] Stream finished -- Passed Checksum"); + } + println!("[INFO] Finished up all allocated sockets for echoing"); +} \ No newline at end of file diff --git a/kernel/src/task/udp_echo.rs b/kernel/src/task/udp_echo.rs index b13cd63..b06d5bb 100644 --- a/kernel/src/task/udp_echo.rs +++ b/kernel/src/task/udp_echo.rs @@ -32,34 +32,34 @@ pub async fn udp_echo_server() { } loop { // Loop trying to read from the socket - let data_or_err = socket.read(0).await; - if let Ok(mut data) = data_or_err { - // Parse the data we read - let user_message = String::from_utf8(data.clone()); - if let Ok(message) = user_message { - // If we can read the user message, print it - print!("[USER] {}", message); - if message.to_lowercase().trim() == "quit" || message.to_lowercase().trim() == "exit" { - // If exit or quit, we close the socket - println!("Closing socket"); - let _ = socket.write(&mut ("Closing socket...\n".as_bytes().to_vec())).await; - socket.close().await; - return; - } - } - // Echo the data - let res_or_err = socket.write(&mut data).await; - if let Err(err) = res_or_err { - // If we got an error -> print, close, exit - println!("[ERR] {:?}", err); - socket.close().await; - break; - } - } else if let Err(err) = data_or_err { + let read_result = socket.read(4).await; // ! + if let Err(err) = read_result { // ! if err == NetworkErrors::Timeout { - // If error is just timeout, we can continue trying to read + // If we timed-out, we just loop again reading the socket continue; } + // Otherwise print the error, close the socket, and break from the loop + println!("[ERR] (Reading): {:?}", err); + socket.close().await; + break; + } + let mut data = read_result.unwrap(); // ! + // Parse the data we read + let user_message = String::from_utf8(data.clone()); + if let Ok(message) = user_message { + // If we can read the user message, print it + print!("{}", message); + if message.to_lowercase().trim() == "quit" || message.to_lowercase().trim() == "exit" { + // If exit or quit, we close the socket + println!("Closing socket"); + let _ = socket.reliable_write(&mut ("Closing socket...\n".as_bytes().to_vec())).await; + socket.close().await; + return; + } + } + // Echo the data + let res_or_err = socket.reliable_write(&mut data).await; + if let Some(err) = res_or_err { // If we got an error -> print, close, exit println!("[ERR] {:?}", err); socket.close().await; diff --git a/src/main.rs b/src/main.rs index d0572f2..3f6a14d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -16,11 +16,11 @@ fn main() { // Can send UDP from localhost:5555 to SOUP:5554 // Also sends TCP from localhost:6666 to SOUP:6664 cmd.arg("-netdev") - .arg("user,id=net0,hostfwd=udp::5555-:5554,hostfwd=tcp::6666-:6664"); + .arg("user,id=net0,hostfwd=udp::5555-:5554,hostfwd=tcp::6666-:6664,hostfwd=tcp::7777-:7774"); // Making sure we have the rtl8139 as a hardware resource cmd.arg("-device").arg("rtl8139,netdev=net0,mac=00:11:22:33:44:55"); // Recording the network traffic in order to conduct debugging and analysis - cmd.arg("-object").arg("filter-dump,id=f1,netdev=net0,file=/tmp/dump.pcap"); + cmd.arg("-object").arg("filter-dump,id=f1,netdev=net0,file=./dump.pcap"); // Turn on debug output cmd.arg("-serial").arg("stdio"); cmd.arg("-device").arg("isa-debug-exit,iobase=0xf4,iosize=0x04"); From 7bd385d4ecd34de38b90597e007bba9beb7b172d Mon Sep 17 00:00:00 2001 From: EthanLavi <89608076+EthanLavi@users.noreply.github.com> Date: Sun, 10 Dec 2023 18:58:17 -0500 Subject: [PATCH 32/36] Integration changes --- kernel/Cargo.toml | 2 + kernel/scripts/add.wasm | Bin 0 -> 263 bytes kernel/scripts/complex.wasm | Bin 0 -> 559 bytes kernel/scripts/send_wasm.py | 36 ++ kernel/scripts/sub.wasm | Bin 0 -> 263 bytes .../{broadcast_module.py => tcp_check.py} | 4 +- .../{test.wasm => tcp_test_payload.txt} | 0 kernel/src/apic.rs | 90 ++++ kernel/src/lib.rs | 6 +- kernel/src/memory_stealer.rs | 487 ++++++++++++++++++ kernel/src/serial.rs | 2 +- 11 files changed, 623 insertions(+), 4 deletions(-) create mode 100755 kernel/scripts/add.wasm create mode 100755 kernel/scripts/complex.wasm create mode 100644 kernel/scripts/send_wasm.py create mode 100755 kernel/scripts/sub.wasm rename kernel/scripts/{broadcast_module.py => tcp_check.py} (83%) rename kernel/scripts/{test.wasm => tcp_test_payload.txt} (100%) create mode 100644 kernel/src/apic.rs create mode 100644 kernel/src/memory_stealer.rs diff --git a/kernel/Cargo.toml b/kernel/Cargo.toml index ed24cf7..ca6ef26 100644 --- a/kernel/Cargo.toml +++ b/kernel/Cargo.toml @@ -19,6 +19,8 @@ pci = "0.0.1" noto-sans-mono-bitmap = "0.2.0" hashbrown = "0.14.2" md-5 = { version = "0.10.6", default-features = false } +arrayvec = { version = "0.7.4", default-features = false } +anyhow = { version = "1.0.75", default-features = false } [dependencies.wasmi] version = "0.31.0" diff --git a/kernel/scripts/add.wasm b/kernel/scripts/add.wasm new file mode 100755 index 0000000000000000000000000000000000000000..1971deb16b7ea77eef5cca9e1ceb96c1cf98bca9 GIT binary patch literal 263 zcmY+7!AiqG6ae4bq>XKah&OLR5R`UH8nNQhgWs{fWS?m?y9v7+FsJm>{4@)mJk4B& z0rCa`026!{^Ne#O5Ec=_)$%5q@%422fpmudOt$gTu<6|-AU4!mfVEPZ0x7mS!c?gv z)1WHy9EtS`6UZ@#&lGNBONP(0zb?5($dvMd>Mu3)tqttpZ8f-Fk2O1wmuIcnd)Bg= zaN$ND>bw2!`LHXbZJT50ru@D*a8c@~s?5uxe3biA7B3I5%>ucGgZf|!qhl{NnHd9B T9m}TcE9(5_!nci;{2hJ)I}k@c literal 0 HcmV?d00001 diff --git a/kernel/scripts/complex.wasm b/kernel/scripts/complex.wasm new file mode 100755 index 0000000000000000000000000000000000000000..d62f68b8bf06305e9f65575abefbbbf28afaebaa GIT binary patch literal 559 zcma)3O>WdM6n@XnG^HUL#164>EFpo~RA~@J?CM2(4dYDAbR5SMVvp33kkVZ*&}$%Z z1P;OhkhnxWOj)pkh2M|m_x(JB@;w3orudOe6Hb&M%qYUMgBNr{aeMn6vI+jq=}B_X z>wZ0M2_(H*Ie;T6TNRYluBDhsd9BqT8|5`6N8bp3C74VC3L<`{1*I2^Ffa@)PsH8* zuOD>SxBqMk!l?LSIL#p=I}{@DGp9yOh7*wwBE8Cip+FYn<4FvAWN~|Z+eP00yyVc? z9g^LT{y~V^88saTacBz4K%gmhydb6+^Wi0*bw$Ev47&oOU0})X0gZ{%C>wvvbt1PW zIuXF2D@YdIY>miW20U@9*YQ*K`%RxC%%$`}Ejl@@S1#xgE}snJy4@`F=<#VR7utK> zmdzGtbK=tT~;67z{xzQajAp+q*d6A b+LPmXzX__bx?I|IqpUwZ^sA*S^-XvHRuh$~ literal 0 HcmV?d00001 diff --git a/kernel/scripts/send_wasm.py b/kernel/scripts/send_wasm.py new file mode 100644 index 0000000..8e61fbd --- /dev/null +++ b/kernel/scripts/send_wasm.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 -u +# Broadcast a WASM module over TCP, then connect the socket to stdio +import sys +import os +import struct +import socket +import time + +if __name__ == "__main__": + if len(sys.argv) < 2: + print(f"Usage: {sys.argv[0]} ", file=sys.stderr) + sys.exit(1) + + # Remove once there are wasm files with less/more than two arguments + if len(sys.argv) != 4: # + print("Remove this section if testing wasm files with less/more than 2 arguments") + print(f"Usage: {sys.argv[0]} ", file=sys.stderr) # + sys.exit(1) # + # Remove once there are wasm files with less/more than two arguments + + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.connect(("localhost", 7777)) + + wasm_path = sys.argv[1] + file_size = os.path.getsize(wasm_path) + s.sendall(struct.pack("C*$ Sjz!b;6?J}d<=e&z{tmx8sYgEm literal 0 HcmV?d00001 diff --git a/kernel/scripts/broadcast_module.py b/kernel/scripts/tcp_check.py similarity index 83% rename from kernel/scripts/broadcast_module.py rename to kernel/scripts/tcp_check.py index c5fd9bd..188a2f1 100644 --- a/kernel/scripts/broadcast_module.py +++ b/kernel/scripts/tcp_check.py @@ -8,7 +8,7 @@ if __name__ == "__main__": if len(sys.argv) != 2: - print(f"Usage: {sys.argv[0]} ", file=sys.stderr) + print(f"Usage: {sys.argv[0]} ", file=sys.stderr) sys.exit(1) s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -16,7 +16,7 @@ wasm_path = sys.argv[1] file_size = os.path.getsize(wasm_path) - s.sendall(struct.pack(" LocalApic { + LocalApic { base_address } + } + + pub fn enable(&self) { + // unsafe { + // // Set the APIC base address in the MSR + // asm!("wrmsr" :: "c"(0x1b), "a"(self.base_address & 0xFFFFFFFF), "d"(self.base_address >> 32)); + // } + } + + /// Enable interrupt handler + pub fn set_interrupt_handler(&self, vector: u8, handler: fn()) {} + + /// Enable a specific interrupt vector + pub fn enable_interrupt(&self, vector: u8) { + // Add code to enable a specific interrupt vector + } + + /// Send an interrupt to another process + pub fn send_interrupt_process(&self, cpu_id: u64, vector: u8) { + let destination = cpu_id << 24; + let icr_register = self.base_address + 0x300; + let icr_value = destination | (vector as u64); + + unsafe { core::ptr::write_volatile(icr_register as *mut u32, icr_value as u32) } + } +} + +/// Check if the CPU architecture supports APIC +pub fn check_apic() -> bool { + let cpuid = CpuId::new(); + + if let Some(features) = cpuid.get_feature_info() { + features.has_apic() + } else { + false + } +} + +/// Read the current timestamp +pub fn get_time_stamp() -> u64 { + unsafe { x86_64::_rdtsc() } +} \ No newline at end of file diff --git a/kernel/src/lib.rs b/kernel/src/lib.rs index c1ffcfc..8028373 100644 --- a/kernel/src/lib.rs +++ b/kernel/src/lib.rs @@ -2,9 +2,13 @@ #![cfg_attr(test, no_main)] #![feature(custom_test_frameworks)] #![feature(const_mut_refs)] +#![feature(impl_trait_in_assoc_type)] +#![feature(type_name_of_val)] +#![feature(type_alias_impl_trait)] +#![feature(abi_x86_interrupt)] #![test_runner(crate::test_runner)] #![reexport_test_harness_main = "test_main"] -#![feature(abi_x86_interrupt)] + #![allow(clippy::missing_safety_doc)] #![allow(clippy::let_and_return)] #![allow(clippy::new_without_default)] diff --git a/kernel/src/memory_stealer.rs b/kernel/src/memory_stealer.rs new file mode 100644 index 0000000..d3084c6 --- /dev/null +++ b/kernel/src/memory_stealer.rs @@ -0,0 +1,487 @@ +use arrayvec::ArrayVec; +use core::alloc::Layout; +use core::cell::{RefCell, RefMut}; +use itertools::Either; +use x86_64::structures::paging::mapper::Mapper; +use x86_64::structures::paging::{ + frame::PhysFrameRange, FrameAllocator, Page, PageSize, PageTableFlags, PhysFrame, Size4KiB, +}; +use x86_64::{PhysAddr, VirtAddr}; + +use crate::allocator::STEAL_START; + +/// A region of physical memory addresses +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub struct PhysAddrRegion { + start: u64, + len: u64, +} + +impl PhysAddrRegion { + pub fn new(start: u64, len: u64) -> Self { + Self { start, len } + } + + pub fn start(&self) -> u64 { + self.start + } + + pub fn start_phys(&self) -> PhysAddr { + PhysAddr::new(self.start()) + } + + pub fn len(&self) -> u64 { + self.len + } + + pub fn end(&self) -> u64 { + self.start + self.len + } + + pub fn end_phys(&self) -> PhysAddr { + PhysAddr::new(self.end()) + } + + pub fn is_empty(&self) -> bool { + self.len == 0 + } +} + +/// Contains the next free virtual address, used for mapping stolen memory. +#[derive(Debug)] +pub struct StolenVirtTidemark(VirtAddr); + +impl StolenVirtTidemark { + /// SAFETY: This should only be called once or else it may overwrite previously-mapped memory. + pub unsafe fn new() -> Self { + Self(VirtAddr::new(STEAL_START as u64)) + } + + pub fn get(&self) -> VirtAddr { + self.0 + } + + // TODO: maybe this should do bound checks I guess... realistically this isn't an issue + pub fn bump(&mut self, amount: u64) { + self.0 += amount; + } +} + +#[derive(Debug)] +pub struct StealerInfo, F: FrameAllocator> { + pub mapper: M, + pub frame_allocator: F, + pub tidemark: StolenVirtTidemark, +} + +impl, F: FrameAllocator> StealerInfo { + pub fn new(mapper: M, frame_allocator: F, tidemark: StolenVirtTidemark) -> Self { + Self { + mapper, + frame_allocator, + tidemark, + } + } +} + +/// Helper trait to simplify generic bounds on the associated mapper/frame allocator types. +pub trait AsStealerInfo { + type M: Mapper; + type F: FrameAllocator; + + fn get_mut(&self) -> RefMut<'_, StealerInfo>; +} + +impl, F: FrameAllocator> AsStealerInfo + for RefCell> +{ + type M = M; + type F = F; + + fn get_mut(&self) -> RefMut<'_, StealerInfo> { + RefCell::borrow_mut(self) + } +} + +/// A "memory stealer" is an object that is given an early opportunity to search for desirable +/// regions of physical memory, and steal parts of them as it sees fit. For example, the RTL8139 +/// driver will steal the first contiguous 12KB region of memory that it finds that fits within a +/// 32-bit physical address space. +pub trait MemoryStealer { + type Iter<'a, I: Iterator>: Iterator + where + Self: 'a; + + /// This method will be called for each memory region that is available to steal. Return an + /// iterator of the remaining sub-regions after any memory is stolen. The returned regions + /// should be a non-overlapping subset of the input region. This method must properly handle + /// empty regions, and may return empty regions as well (which will be ignored). + /// + /// SAFETY: The caller must ensure that provided region consists of available, unmapped physical + /// memory, and the return value must be respected (in that only the returned subregions may + /// still be mapped). Furthermore, the callee must ensure to only return unused subregions of + /// the input region. + unsafe fn steal>(&mut self, regions: I) + -> Self::Iter<'_, I>; +} + +/// Attempt to steal a contiguous region of physical memory with the given size and alignment. +/// Returns a tuple of `(stolen_region, left_remaining_region, right_remaining_region)` +pub fn try_steal( + region: PhysAddrRegion, + layout: Layout, +) -> Option<(PhysAddrRegion, PhysAddrRegion, PhysAddrRegion)> { + let size = layout.size() as u64; + let align = layout.align() as u64; + let start = region.start_phys().align_up(align).as_u64(); + let end = start + size; + if start + size <= region.end() { + Some(( + PhysAddrRegion::new(start, end - start), + PhysAddrRegion::new(region.start(), start - region.start()), + PhysAddrRegion::new(end, region.end() - end), + )) + } else { + None + } +} + +/// Like `try_steal()`, but returns an _iterator_ of the remaining un-stolen memory regions. +pub fn try_steal_iter( + region: PhysAddrRegion, + layout: Layout, +) -> (Option, impl Iterator) { + match try_steal(region, layout) { + Some((stolen, left, right)) => (Some(stolen), Either::Left([left, right].into_iter())), + None => (None, Either::Right([region].into_iter())), + } +} + +/// Like `try_steal()`, but steals several frames of the given size. +pub fn try_steal_frames( + region: PhysAddrRegion, + num_frames: usize, +) -> Option<(PhysFrameRange, PhysAddrRegion, PhysAddrRegion)> { + let layout = Layout::from_size_align(num_frames * S::SIZE as usize, S::SIZE as usize).unwrap(); + let (stolen_region, left, right) = try_steal(region, layout)?; + let start = PhysFrame::from_start_address(PhysAddr::new(stolen_region.start())).unwrap(); + let end = PhysFrame::from_start_address(PhysAddr::new(stolen_region.end())).unwrap(); + Some((PhysFrame::range(start, end), left, right)) +} + +/// Like `try_steal_frames()`, but returns an _iterator_ of the remaining un-stolen memory regions. +pub fn try_steal_frames_iter( + region: PhysAddrRegion, + num_frames: usize, +) -> ( + Option>, + impl Iterator, +) { + match try_steal_frames(region, num_frames) { + Some((stolen, left, right)) => (Some(stolen), Either::Left([left, right].into_iter())), + None => (None, Either::Right([region].into_iter())), + } +} + +/// A region filter that lets you choose whether or not to steal a region. +pub trait RegionFilter { + fn filter(&self, region: PhysAddrRegion) -> bool; +} + +/// The default filter, accepts any region. +#[derive(Debug)] +pub struct NoOpRegionFilter; + +impl RegionFilter for NoOpRegionFilter { + fn filter(&self, _region: PhysAddrRegion) -> bool { + true + } +} + +/// A region filter that only accepts regions whose entire physical address range (including the +/// address one past the end of the region) fit within 32 bits +#[derive(Debug)] +pub struct RegionFilter32; + +impl RegionFilter for RegionFilter32 { + fn filter(&self, region: PhysAddrRegion) -> bool { + region.end() <= u32::MAX as u64 + } +} + +/// A region filter that only accepts regions whose entire physical address range (including the +/// address one past the end of the region) fit within 16 bits +#[derive(Debug)] +pub struct RegionFilter16; + +impl RegionFilter for RegionFilter16 { + fn filter(&self, region: PhysAddrRegion) -> bool { + region.end() <= u16::MAX as u64 + } +} + +/// A memory stealer that attempts to steal a fixed-sized range of contiguous bytes +#[derive(Clone, Debug)] +pub struct RawRangeStealer { + layout: Layout, + addr: Option, + region_filter: F, +} + +impl RawRangeStealer { + pub fn new(layout: Layout) -> Self { + Self { + layout, + addr: None, + region_filter: NoOpRegionFilter, + } + } +} + +impl RawRangeStealer { + pub fn with_filter(layout: Layout, region_filter: F) -> Self { + Self { + layout, + addr: None, + region_filter, + } + } + + pub fn addr(&self) -> Option { + self.addr + } +} + +impl MemoryStealer for RawRangeStealer { + type Iter<'a, I: Iterator> = impl Iterator + where + Self: 'a; + + unsafe fn steal>( + &mut self, + regions: I, + ) -> Self::Iter<'_, I> { + regions.flat_map(|region| { + if self.addr.is_some() { + // we already stole once, don't steal again + return Either::Left([region].into_iter()); + } + let Some((stolen_range, left, right)) = try_steal(region, self.layout) else { + // we couldn't find a range that fits + return Either::Left([region].into_iter()); + }; + if !self.region_filter.filter(region) { + // we were rejectec by the region filter + return Either::Left([region].into_iter()); + } + + // by here, we've decided this is an acceptable region! + self.addr = Some(stolen_range.start_phys()); + + // return the unused regions + Either::Right([left, right].into_iter()) + }) + } +} + +/// A memory stealer that attempts to steal AND MAP a fixed-sized range of contiguous frames +#[derive(Clone, Debug)] +pub struct FixedFrameStealer<'x, const NFRAMES: usize, X, F = NoOpRegionFilter> { + addr: Option<(PhysAddr, VirtAddr)>, + info: &'x X, + region_filter: F, +} + +impl<'x, const NFRAMES: usize, X> FixedFrameStealer<'x, NFRAMES, X> { + pub fn new(info: &'x X) -> Self { + Self { + addr: None, + info, + region_filter: NoOpRegionFilter, + } + } +} + +impl<'x, const NFRAMES: usize, X, F> FixedFrameStealer<'x, NFRAMES, X, F> { + pub fn with_filter(info: &'x X, region_filter: F) -> Self { + Self { + addr: None, + info, + region_filter, + } + } + + pub fn addr(&self) -> Option<(PhysAddr, VirtAddr)> { + self.addr + } + + /// Get the physical address of the stolen region, along with the contents as a pointer to a + /// slice of bytes. Returns `None` if stealing was unsuccessful. + pub fn into_ptr(self) -> Option<(PhysAddr, *mut [u8])> { + let (phys, virt) = self.addr?; + let ptr = core::ptr::slice_from_raw_parts_mut( + virt.as_mut_ptr(), + NFRAMES * Size4KiB::SIZE as usize, + ); + Some((phys, ptr)) + } + + /// Get the physical address of the stolen region, along with the contents as a slice of bytes. + /// Returns `None` if stealing was unsuccessful. + pub fn into_buf<'a>(self) -> Option<(PhysAddr, &'a mut [u8])> { + let (phys, ptr) = self.into_ptr()?; + Some((phys, unsafe { &mut *ptr })) + } +} + +impl MemoryStealer + for FixedFrameStealer<'_, NFRAMES, X, F> +{ + type Iter<'a, I: Iterator> = impl Iterator + where + Self: 'a; + + unsafe fn steal>( + &mut self, + regions: I, + ) -> Self::Iter<'_, I> { + regions.flat_map(|region| { + if self.addr.is_some() { + // we already stole once, don't steal again + return Either::Left([region].into_iter()); + } + let Some((stolen_range, left, right)) = try_steal_frames(region, NFRAMES) else { + // we couldn't find a range that fits + return Either::Left([region].into_iter()); + }; + if !self.region_filter.filter(region) { + // we were rejectec by the region filter + return Either::Left([region].into_iter()); + } + + // by here, we've decided this is an acceptable region! let's map it + let inner_info = &mut *self.info.get_mut(); + let old_tidemark = inner_info.tidemark.get(); + + let phys = stolen_range.start.start_address(); + let virt = old_tidemark.align_up(4096u64); + + let flags = PageTableFlags::PRESENT | PageTableFlags::WRITABLE; + let first_page = Page::from_start_address(virt).unwrap(); + let page_range = Page::range(first_page, first_page + NFRAMES as u64); + for (page, frame) in page_range.into_iter().zip(stolen_range.into_iter()) { + unsafe { + inner_info + .mapper + .map_to(page, frame, flags, &mut inner_info.frame_allocator) + } + .unwrap() + .flush(); + } + inner_info + .tidemark + .bump(page_range.end.start_address() - old_tidemark); + self.addr = Some((phys, virt)); + + // return the unused regions + Either::Right([left, right].into_iter()) + }) + } +} + +/// Allocates the 3 frames needed for the recursive page tables +#[derive(Debug, Default)] +pub struct RecursiveFrameStealer { + frames: RefCell>, +} + +impl RecursiveFrameStealer { + pub fn new() -> Self { + Default::default() + } +} + +impl MemoryStealer for &RecursiveFrameStealer { + type Iter<'a, I: Iterator> = impl Iterator where Self: 'a; + + unsafe fn steal>( + &mut self, + regions: I, + ) -> Self::Iter<'_, I> { + regions.flat_map(|region| { + let mut frames = self.frames.borrow_mut(); + if frames.is_full() { + Either::Left([region].into_iter()) + } else if let Some((stolen, left, mut right)) = try_steal_frames(region, 1) { + frames.push(stolen.start); + while !frames.is_full() { + if let Some((stolen, new_left, new_right)) = try_steal_frames(right, 1) { + assert!(new_left.is_empty()); + frames.push(stolen.start); + right = new_right; + } else { + break; + } + } + Either::Right([left, right].into_iter()) + } else { + Either::Left([region].into_iter()) + } + }) + } +} + +// TODO: document +#[derive(Debug)] +pub struct RecursiveFrameStealerAllocator<'a> { + inner: &'a RecursiveFrameStealer, +} + +impl<'a> RecursiveFrameStealerAllocator<'a> { + pub fn new(inner: &'a RecursiveFrameStealer) -> Self { + Self { inner } + } +} + +unsafe impl FrameAllocator for RecursiveFrameStealerAllocator<'_> { + fn allocate_frame(&mut self) -> Option { + self.inner.frames.borrow_mut().pop() + } +} + +/// A force-stealer: always steals a fixed range precisely +#[derive(Debug)] +pub struct ForceStealer { + region: PhysAddrRegion, +} + +impl ForceStealer { + pub fn new(region: PhysAddrRegion) -> Self { + Self { region } + } +} + +impl MemoryStealer for ForceStealer { + type Iter<'a, I: Iterator> = impl Iterator + where + Self: 'a; + + unsafe fn steal>( + &mut self, + regions: I, + ) -> Self::Iter<'_, I> { + regions.flat_map(|region| { + let mut valid = ArrayVec::<_, 2>::new(); + if region.start() < self.region.start() { + let end = region.end().min(self.region.start()); + valid.push(PhysAddrRegion::new(region.start(), end - region.start())); + } + if self.region.end() < region.end() { + let start = region.start().max(self.region.end()); + valid.push(PhysAddrRegion::new(start, region.end() - start)); + } + + valid.into_iter() + }) + } +} \ No newline at end of file diff --git a/kernel/src/serial.rs b/kernel/src/serial.rs index 5e9798c..719c3ca 100644 --- a/kernel/src/serial.rs +++ b/kernel/src/serial.rs @@ -24,7 +24,7 @@ pub fn _print(args: ::core::fmt::Arguments) { #[macro_export] macro_rules! serial_print { ($($arg:tt)*) => { - $crate::serial::_print(format_args!($($arg)*)); + $crate::serial::_print(format_args!($($arg)*)) }; } From 3c2988e713db9c9a376a73eb916a2e7f3cadf5b1 Mon Sep 17 00:00:00 2001 From: EthanLavi <89608076+EthanLavi@users.noreply.github.com> Date: Sun, 10 Dec 2023 18:58:36 -0500 Subject: [PATCH 33/36] Adding crypto stuff --- kernel/src/crypto/aes.md | 48 ++++++++++ kernel/src/crypto/ecdh.md | 45 +++++++++ kernel/src/crypto/mod.rs | 4 + kernel/src/crypto/random.md | 48 ++++++++++ kernel/src/crypto/request.md | 181 +++++++++++++++++++++++++++++++++++ 5 files changed, 326 insertions(+) create mode 100644 kernel/src/crypto/aes.md create mode 100644 kernel/src/crypto/ecdh.md create mode 100644 kernel/src/crypto/random.md create mode 100644 kernel/src/crypto/request.md diff --git a/kernel/src/crypto/aes.md b/kernel/src/crypto/aes.md new file mode 100644 index 0000000..be78f94 --- /dev/null +++ b/kernel/src/crypto/aes.md @@ -0,0 +1,48 @@ +use aes::Aes256; +use aes::cipher::{ + BlockEncrypt, BlockDecrypt, KeyInit, + generic_array::{GenericArray, typenum::U16}, +}; +use alloc::vec::Vec; + +const BLOCK_SIZE: usize = 16; + +pub struct ServerAES { + cipher: Aes256, +} + +impl ServerAES { + pub fn new(key: [u8; 32]) -> Self { + let key_arr = GenericArray::from(key); + let cipher = Aes256::new(&key_arr); + Self { + cipher, + } + } + + pub fn encrypt_msg(&self, msg: Vec) -> Vec { + // Create an empty vector to store the encrypted bytes + let mut encrypted_bytes = Vec::new(); + + // Loop over chunks of 16 bytes + for chunk in msg.chunks_exact(BLOCK_SIZE) { + let mut bytes: GenericArray<_, U16> = GenericArray::clone_from_slice(chunk); + self.cipher.encrypt_block(&mut bytes); + encrypted_bytes.extend_from_slice(bytes.as_slice()); + } + encrypted_bytes + } + + pub fn decrypt_msg(&self, msg: Vec) -> Vec { + // Create an empty vector to store the decrypted bytes + let mut decrypted_bytes = Vec::new(); + + // Loop over chunks of 16 bytes + for chunk in msg.chunks_exact(BLOCK_SIZE) { + let mut bytes: GenericArray<_, U16> = GenericArray::clone_from_slice(chunk); + self.cipher.decrypt_block(&mut bytes); + decrypted_bytes.extend_from_slice(bytes.as_slice()); + } + decrypted_bytes + } +} \ No newline at end of file diff --git a/kernel/src/crypto/ecdh.md b/kernel/src/crypto/ecdh.md new file mode 100644 index 0000000..c01cadb --- /dev/null +++ b/kernel/src/crypto/ecdh.md @@ -0,0 +1,45 @@ +use alloc::string::String; +use p256::{EncodedPoint, PublicKey, ecdh::EphemeralSecret, ecdh::SharedSecret}; +use rand_core::OsRng; +pub struct ECDH { + ephemeral_secret: EphemeralSecret, + public_key_obj: PublicKey, + public_key_hex: String, + shared_key: Option, +} + +impl ECDH { + pub fn new() -> Self { + let ephemeral_secret = EphemeralSecret::random(&mut OsRng); + let encoded = EncodedPoint::from(ephemeral_secret.public_key()); + let public_key_hex = hex::encode(encoded.as_ref()); + let public_key_obj = PublicKey::from_sec1_bytes(encoded.as_ref()).expect("Alice's public key invalid"); + Self { + ephemeral_secret, + public_key_obj, + public_key_hex, + shared_key: None, + } + } + pub fn generate_shared_key(&mut self, client_key: PublicKey) { + let shared_key = self.ephemeral_secret.diffie_hellman(&client_key); + self.shared_key = Some(shared_key); + } + + pub fn get_public_key_obj(&self) -> PublicKey{ + self.public_key_obj + } + + pub fn get_public_key_hex(&self) -> String { + self.public_key_hex.clone() + } + + pub fn get_shared_key_hex(&self) -> String { + hex::encode(self.shared_key.as_ref().expect("Failed to unwrap shared key").as_bytes()) + } +} + +pub fn create_pubkey_from_hex_string(hex: String) -> PublicKey { + let bytes = hex::decode(hex).expect("Failed to decode hexadecimal string"); + PublicKey::from_sec1_bytes(&bytes).expect("Failed to create PublicKey from hex") +} \ No newline at end of file diff --git a/kernel/src/crypto/mod.rs b/kernel/src/crypto/mod.rs index 61ee0f4..3e9c8fa 100644 --- a/kernel/src/crypto/mod.rs +++ b/kernel/src/crypto/mod.rs @@ -1 +1,5 @@ +// pub mod request; +// pub mod aes; +// pub mod ecdh; +// pub mod random; pub mod rng; diff --git a/kernel/src/crypto/random.md b/kernel/src/crypto/random.md new file mode 100644 index 0000000..e1b2691 --- /dev/null +++ b/kernel/src/crypto/random.md @@ -0,0 +1,48 @@ +use rand::RngCore; + +// Custom `getrandom` implementation that uses a simple PRNG (XorShift). +pub fn get_rand(buf: &mut [u8]) -> Result<(), getrandom::Error> { + let mut rng = XorShiftRng::new(); + rng.fill_bytes(buf); + Ok(()) +} + +// XorShiftRng: A simple pseudo-random number generator (PRNG). +struct XorShiftRng { + state: u32, +} + +impl XorShiftRng { + fn new() -> Self { + // Seed the PRNG with some initial value (e.g., current timestamp). + Self { state: 123456789 } + } +} + +impl RngCore for XorShiftRng { + fn next_u32(&mut self) -> u32 { + self.state ^= self.state << 13; + self.state ^= self.state >> 17; + self.state ^= self.state << 5; + self.state + } + + fn next_u64(&mut self) -> u64 { + // Combine two 32-bit outputs to form a 64-bit output. + ((self.next_u32() as u64) << 32) | (self.next_u32() as u64) + } + + fn fill_bytes(&mut self, dest: &mut [u8]) { + // Use `slice::chunks_mut` to fill the destination buffer efficiently. + for chunk in dest.chunks_mut(4) { + let value = self.next_u32(); + chunk.copy_from_slice(&value.to_le_bytes()); + } + } + + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> { + // For simplicity, always succeed in filling the buffer. + self.fill_bytes(dest); + Ok(()) + } +} \ No newline at end of file diff --git a/kernel/src/crypto/request.md b/kernel/src/crypto/request.md new file mode 100644 index 0000000..d3a39e5 --- /dev/null +++ b/kernel/src/crypto/request.md @@ -0,0 +1,181 @@ +use alloc::{vec::Vec, string::String}; +use lazy_static::lazy_static; +use spin::Mutex; +use crate::{println, crypto::{ecdh::{ECDH, create_pubkey_from_hex_string}, aes::ServerAES}}; +use hex::FromHex; +use hashbrown::HashMap; +use rand_core::{RngCore, OsRng}; +const DEBUG: bool = true; + +lazy_static! { + static ref ECDH_KEY: Mutex = Mutex::new(ECDH::new()); + static ref AES_KEY: Mutex = Mutex::new(ServerAES::new([0u8; 32])); + static ref KEY_MAP: Mutex> = Mutex::new(HashMap::new()); +} + +pub fn handle_request(message: Vec) -> Vec { + println!("[USER] {:?}", message); + + let mut res = Vec::new(); + + if message.len() < 7 { + if DEBUG { + println!("Client's message too short to contain command"); + } + res = Vec::from("ERR: Command too short"); + } + else { + if message.starts_with(b"key_reg") { + if DEBUG { + println!("Recieved client public key: {}", (String::from_utf8(message.to_vec()).expect("Failed to stringify response"))); + } + let message: Vec = (&message[7..]).to_vec(); + // set up ecdh shared key and respond with server's public key + let mut ecdh = ECDH::new(); + ecdh.generate_shared_key(create_pubkey_from_hex_string(String::from_utf8(message.to_vec()).expect("failed to make string"))); + res = ecdh.get_public_key_hex().into_bytes().to_vec(); + + // set up aes cipher + let mut aes = ServerAES::new([0u8;32]); + // use shared key as [u8; 32] for the aes cipher + match <[u8; 32]>::from_hex(ecdh.get_shared_key_hex()) { + Ok(result) => aes = ServerAES::new(result), + Err(e) => println!("Conversion to [u8; 32] failed: {}", e), + } + + // use client public key as key and set values to the ecdh and aes structs + let mut key_map = KEY_MAP.lock(); + key_map.insert(String::from_utf8(message.to_vec()).expect("failed to convert to string"), (ecdh, aes)); + } + else { + // get ecdh and aes contexts from map + let unlocked_key_map = KEY_MAP.lock(); + // key will be the first 130 bytes + if message.len() < 137 { + res = Vec::from("ERR: Missing public key or command"); + } + else { + // then request will be in form publickey+enc(length+command+message) + let pubkey = (&message[..130]).to_vec(); + + if let Ok(key_str) = String::from_utf8(pubkey) { + if let Some((_ecdh, aes)) = unlocked_key_map.get(&key_str) { + + if DEBUG { + println!("found public key"); + } + + let decoded = decode_message((&message[130..]).to_vec(), aes); + + let command = (&decoded[..7]).to_vec(); + let message: Vec = (&decoded[7..]).to_vec(); + + // determine what to do based on command + if command.starts_with(b"runwasm") { + // call scheduler with bits + // let program_res = schedule_task(message) + // res = encode_message(program_res, aes); + res = encode_message(Vec::from("running task"), aes); + } + else if command.starts_with(b"ping_me") { + let mut return_message = Vec::from("Pong: "); + return_message.extend(message); + res = encode_message(return_message, aes); + } + else { + res = encode_message(Vec::from("ERR: Invalid command"), aes); + } + } else { + if DEBUG { + println!("Couldn't find Client's public key"); + } + res = Vec::from("ERR: Client's public key not found"); + } + } else { + res = Vec::from("ERR: Client's public key format is incorrect"); + } + } + } + } + res +} + +fn decode_message(message: Vec, aes: &ServerAES) -> Vec { + + if DEBUG { + println!("Decrypting message: {:?}", message); + } + + // decrypt the rest of the message + let request = aes.decrypt_msg(message); + + if DEBUG { + println!("Decrypted message: {:?}", request); + } + + // get the size of the unencrypted message + let message_size = bytes_to_u64(get_first_x_bytes(request.clone(), 8)); + + // get the actual message + let unencrypted_message = get_first_x_bytes(request[8..].to_vec(), message_size); + + if DEBUG { + println!("Unencrypted message: {:?}", unencrypted_message.clone()); + } + unencrypted_message +} + +fn encode_message(data_vec: Vec, aes_struct: &ServerAES) -> Vec { + let mut data_size = u64_to_bytes(data_vec.len() as u64); + + // combine the message with length of message before padding + data_size.extend(data_vec); + + // pad data to multiple of 16 with random bytes + pad_to_16(&mut data_size); + + // encrypt the data + if DEBUG { + println!("Encrypting message: {:?}", data_size); + } + let message = aes_struct.encrypt_msg(data_size); + + if DEBUG { + println!("Encrypted message: {:?}", message); + } + + // send the request to server and decrypt the response + if DEBUG { + println!("Sending request: {:?}", message); + } + + message +} + +fn u64_to_bytes(length: u64) -> Vec { + length.to_le_bytes().to_vec() +} + +fn bytes_to_u64(bytes: Vec) -> u64 { + let mut array = [0u8; 8]; + array.copy_from_slice(&bytes); + u64::from_le_bytes(array) +} + +fn get_first_x_bytes(data: Vec, x: u64) -> Vec { + let mut collected = Vec::new(); + for i in 0..x { + collected.push(data[i as usize]); + } + collected +} + +fn pad_to_16(data: &mut Vec) { + let remaining_bytes = 16 - (data.len() % 16); + + if remaining_bytes != 16 { + let mut rng = OsRng::default(); // Use OsRng to generate random bytes + let random_bytes: Vec = (0..remaining_bytes).map(|_| rng.next_u32() as u8).collect(); + data.extend(random_bytes); + } +} \ No newline at end of file From 409b175c7e404130189e43fea9457e5ed3fb7177 Mon Sep 17 00:00:00 2001 From: EthanLavi <89608076+EthanLavi@users.noreply.github.com> Date: Sun, 10 Dec 2023 19:02:11 -0500 Subject: [PATCH 34/36] Final bug fixes and oneshot --- kernel/src/allocator.rs | 2 +- kernel/src/main.rs | 5 +- kernel/src/network/rtl8139.rs | 4 +- kernel/src/network/socket.rs | 17 ++- kernel/src/task/mod.rs | 1 + kernel/src/task/tcp_echo.rs | 17 +-- kernel/src/task/test_reader.rs | 14 +- kernel/src/task/udp_echo.rs | 13 +- kernel/src/task/wasm_async.rs | 64 +++++++++ kernel/src/task/wasm_oneshot.rs | 229 ++++++++++++++++++++++++++++++++ wasm-demos/build.rs | 2 + wasm-demos/src/add.rs | 2 +- wasm-demos/src/complex.rs | 24 ++++ wasm-demos/src/sub.rs | 2 +- 14 files changed, 354 insertions(+), 42 deletions(-) create mode 100644 kernel/src/task/wasm_async.rs create mode 100644 kernel/src/task/wasm_oneshot.rs create mode 100644 wasm-demos/src/complex.rs diff --git a/kernel/src/allocator.rs b/kernel/src/allocator.rs index 192a403..2ca4722 100644 --- a/kernel/src/allocator.rs +++ b/kernel/src/allocator.rs @@ -7,7 +7,7 @@ pub mod linked_list; use fixed_size_block::FixedSizeBlockAllocator; pub const HEAP_START: *mut u8 = 0x_4444_4444_0000 as *mut u8; -pub const HEAP_SIZE: usize = 10 * 100 * 1024; +pub const HEAP_SIZE: usize = 10 * 1024 * 1024; #[global_allocator] // static ALLOCATOR: LockedHeap = LockedHeap::empty(); diff --git a/kernel/src/main.rs b/kernel/src/main.rs index 523ae9f..19cd0b2 100644 --- a/kernel/src/main.rs +++ b/kernel/src/main.rs @@ -4,7 +4,7 @@ #![test_runner(kernel::test_runner)] #![reexport_test_harness_main = "test_main"] -use kernel::{QemuExitCode, serial_println, network::test::{test_async, test_sync}}; +use kernel::{QemuExitCode, serial_println, network::test::{test_async, test_sync}, task::wasm_oneshot}; use bootloader_api::{ config::{BootloaderConfig, Mapping}, entry_point, BootInfo, @@ -109,7 +109,8 @@ fn kernel_main(boot_info: &'static mut BootInfo) -> ! { executor.spawn(Task::new(keyboard::print_keypresses())); executor.spawn(Task::new(udp_echo::udp_echo_server())); executor.spawn(Task::new(tcp_echo::tcp_echo_server())); - executor.spawn(Task::new(test_reader::test_reader_server())); + // executor.spawn(Task::new(test_reader::test_reader_server())); + executor.spawn(Task::new(wasm_oneshot::wasm_oneshot_server())); executor.run(); } diff --git a/kernel/src/network/rtl8139.rs b/kernel/src/network/rtl8139.rs index fbfe84c..21c0d0e 100644 --- a/kernel/src/network/rtl8139.rs +++ b/kernel/src/network/rtl8139.rs @@ -378,13 +378,11 @@ impl RTL8139 { true } - // todo: change to borrowed slice && make all enums derive debug - // ! ^^ /// Write the packet data to the card and notify /// This will send the packet /// - dest_ip only used to determine which interface to send on (local or network) /// Will return true if send on the network, false if local - pub fn send_packet(&self, packet_data: &Vec, dest_ip: u32) -> bool { + pub fn send_packet(&self, packet_data: &[u8], dest_ip: u32) -> bool { let local = NetworkQuery::parse_ip("127.0.0.1").unwrap(); let my_ip = self.ip_address.unwrap_or(local); if dest_ip == local || dest_ip == my_ip { diff --git a/kernel/src/network/socket.rs b/kernel/src/network/socket.rs index 0e93650..e66fbe8 100644 --- a/kernel/src/network/socket.rs +++ b/kernel/src/network/socket.rs @@ -571,6 +571,10 @@ impl Socket { } let pkt = pkt_or_err.unwrap(); if pkt.get_type() == LayerType::Tcp { + let mut data = pkt.unwrap_tcp().data; + if !data.is_empty() { + self.read_buffer.append(&mut data); + } // We got a TCP packet let mut tcp_session_guard = NET_INFO.config.lock(); // Check the acknowledgement to make sure everything is acked @@ -744,8 +748,15 @@ pub async fn test() -> Result<(), String> { receiver.close().await; server_socket.close().await; - // todo: what happens if we both write and then read... !! - // the polling of the writing could cause read to break! -- we are ignoring this for now... - + // Both writing and then both reading + let mut server_socket = Socket::open(SocketType::TCP, 60001, 5).await.unwrap(); + let mut t1 = Socket::connect(SocketType::TCP, my_ip, 60001, 60000, 5).await.unwrap(); + let mut t2 = server_socket.listen().await.unwrap().unwrap(); + t1.reliable_write(&mut vec![1, 2, 3, 4]).await; + t2.reliable_write(&mut vec![5, 6, 7, 8]).await; + let data1 = t1.read(4).await.unwrap(); + let data2 = t2.read(4).await.unwrap(); + check!(data1 == vec![5, 6, 7, 8], "Data1 is t2's message"); + check!(data2 == vec![1, 2, 3, 4], "Data2 is t1's message"); test_ok!(); } \ No newline at end of file diff --git a/kernel/src/task/mod.rs b/kernel/src/task/mod.rs index 6fe3ec7..112a2a2 100644 --- a/kernel/src/task/mod.rs +++ b/kernel/src/task/mod.rs @@ -11,6 +11,7 @@ pub mod timeout; pub mod udp_echo; pub mod test_reader; pub mod wasm_oneshot; +mod wasm_async; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] pub struct TaskId(u64); diff --git a/kernel/src/task/tcp_echo.rs b/kernel/src/task/tcp_echo.rs index 3bd4370..2c08804 100644 --- a/kernel/src/task/tcp_echo.rs +++ b/kernel/src/task/tcp_echo.rs @@ -29,8 +29,8 @@ pub async fn tcp_echo_server() { loop { // Continuously read from the socket // let (mut data, error) = socket.read(4).await; - let read_result = socket.read(4).await; // ! - if let Err(err) = read_result { // ! + let read_result = socket.read(1).await; + if let Err(err) = read_result { if err == NetworkErrors::Timeout { // If we timed-out, we just loop again reading the socket continue; @@ -40,7 +40,7 @@ pub async fn tcp_echo_server() { socket.close().await; break; } - let mut data = read_result.unwrap(); // ! + let mut data = read_result.unwrap(); if data.is_empty() { // continue if we didn't read any data continue; @@ -50,17 +50,6 @@ pub async fn tcp_echo_server() { if let Ok(message) = user_message { // Print it out print!("{}", message); - if message.to_lowercase().trim() == "quit" || message.to_lowercase().trim() == "exit" { - // If their message is quit or exit, we close the connection - let mut exit_message = "Closing socket...\n".as_bytes().to_vec(); - let is_err = socket.reliable_write(&mut exit_message).await; - if is_err.is_some() { - println!("[ERR] Writing error {:?}", is_err.unwrap()); - } - socket.close().await; - println!("Closed socket"); - break; - } } // Echo back the data from the socket let res_or_err = socket.reliable_write(&mut data).await; diff --git a/kernel/src/task/test_reader.rs b/kernel/src/task/test_reader.rs index 2addffe..210bf12 100644 --- a/kernel/src/task/test_reader.rs +++ b/kernel/src/task/test_reader.rs @@ -13,11 +13,11 @@ use md5::{Md5, Digest}; /// - WASM binary module #[allow(clippy::question_mark)] async fn get_wasm_file(mut stream: Socket) -> Result, NetworkErrors> { - let data = stream.read(4).await; // ! - if let Err(err) = data { // ! + let data = stream.read(4).await; + if let Err(err) = data { return Err(err); } - let len_bytes = data.unwrap(); // ! + let len_bytes = data.unwrap(); assert!(len_bytes.len() == 4); let mut bytes: [u8; 4] = [0; 4]; for (i, byte) in len_bytes.iter().enumerate() { @@ -25,14 +25,14 @@ async fn get_wasm_file(mut stream: Socket) -> Result, NetworkErrors> { } let file_size = u32::from_le_bytes(bytes); serial_println!("File is {} bytes", file_size); - let data = stream.read(file_size as usize).await; // ! - if let Err(err) = data { // ! + let data = stream.read(file_size as usize).await; + if let Err(err) = data { if err != NetworkErrors::ClosedSocket { return Err(err); } - return Err(NetworkErrors::FeatureNotAvailableYet) // ! -- this was the bug, if we read an error instead of data + return Err(NetworkErrors::FeatureNotAvailableYet); } - Ok(data.unwrap()) // ! + Ok(data.unwrap()) } pub async fn test_reader_server() { diff --git a/kernel/src/task/udp_echo.rs b/kernel/src/task/udp_echo.rs index b06d5bb..b9dc084 100644 --- a/kernel/src/task/udp_echo.rs +++ b/kernel/src/task/udp_echo.rs @@ -32,8 +32,8 @@ pub async fn udp_echo_server() { } loop { // Loop trying to read from the socket - let read_result = socket.read(4).await; // ! - if let Err(err) = read_result { // ! + let read_result = socket.read(1).await; + if let Err(err) = read_result { if err == NetworkErrors::Timeout { // If we timed-out, we just loop again reading the socket continue; @@ -43,19 +43,12 @@ pub async fn udp_echo_server() { socket.close().await; break; } - let mut data = read_result.unwrap(); // ! + let mut data = read_result.unwrap(); // Parse the data we read let user_message = String::from_utf8(data.clone()); if let Ok(message) = user_message { // If we can read the user message, print it print!("{}", message); - if message.to_lowercase().trim() == "quit" || message.to_lowercase().trim() == "exit" { - // If exit or quit, we close the socket - println!("Closing socket"); - let _ = socket.reliable_write(&mut ("Closing socket...\n".as_bytes().to_vec())).await; - socket.close().await; - return; - } } // Echo the data let res_or_err = socket.reliable_write(&mut data).await; diff --git a/kernel/src/task/wasm_async.rs b/kernel/src/task/wasm_async.rs new file mode 100644 index 0000000..e3caa47 --- /dev/null +++ b/kernel/src/task/wasm_async.rs @@ -0,0 +1,64 @@ +use core::future::Future; +use wasmi::{ + core::HostError, AsContextMut, Error, ResumableInvocation, TypedFunc, TypedResumableCall, + Value, WasmParams, WasmResults, +}; + +pub trait AsyncTrap: HostError { + type ValueSlice: AsRef<[Value]>; + type Fut<'a, C: AsContextMut>: Future>; + + fn handle_trap>( + &self, + ctx: C, + invocation: &ResumableInvocation, + ) -> Self::Fut<'_, C>; +} + +pub trait AsyncTypedResumableFunc { + type Fut, T, C: AsContextMut>: Future< + Output = Result, + >; + + fn call_resumable_async, T, C: AsContextMut>( + &self, + ctx: C, + params: Params, + ) -> Self::Fut; +} + +impl AsyncTypedResumableFunc + for TypedFunc +{ + type Fut, T, C: AsContextMut> = + impl Future>; + + fn call_resumable_async, T, C: AsContextMut>( + &self, + mut ctx: C, + params: Params, + ) -> Self::Fut { + let first_call = self + .call_resumable(ctx.as_context_mut(), params) + .map_err(Error::Trap); + async move { + let mut call_result = first_call?; + loop { + return match call_result { + TypedResumableCall::Finished(res) => Ok(res), + TypedResumableCall::Resumable(invocation) => { + // TODO: should this do something other than panic on downcast failure? + let host_err: &A = invocation + .host_error() + .downcast_ref() + .expect("Unknown host error"); + let new_params = host_err.handle_trap(&mut ctx, &invocation).await?; + call_result = + invocation.resume(ctx.as_context_mut(), new_params.as_ref())?; + continue; + } + }; + } + } + } +} \ No newline at end of file diff --git a/kernel/src/task/wasm_oneshot.rs b/kernel/src/task/wasm_oneshot.rs new file mode 100644 index 0000000..bd4339f --- /dev/null +++ b/kernel/src/task/wasm_oneshot.rs @@ -0,0 +1,229 @@ +use crate::{ + network::{socket::{Socket, SocketType}, errors::NetworkErrors}, + println, serial_println, +}; +use core::fmt; +use futures_util::Future; +use lazy_static::lazy_static; +use wasmi::{ + core::{HostError, Trap}, + AsContext, AsContextMut, Caller, Module, Store, Value, +}; + +use super::wasm_async::{AsyncTrap, AsyncTypedResumableFunc}; + +#[derive(Debug)] +struct OneShotState { + socket: Socket, + // instance: Instance, + memory: Option, +} + +#[derive(Debug)] +struct WasmOneshot(Store); + +#[derive(Debug)] +enum UnrecoverableSocketTrap { + NoExportedMemory, + OutOfBounds, +} + +impl fmt::Display for UnrecoverableSocketTrap { + // required for `HostError` + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(self, f) + } +} + +impl HostError for UnrecoverableSocketTrap {} + +#[derive(Debug)] +enum SocketTrap { + Read { buf: u32, len: u32 }, + Write { buf: u32, len: u32 }, +} + +impl fmt::Display for SocketTrap { + // required for `HostError` + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(self, f) + } +} + +impl HostError for SocketTrap {} + +impl AsyncTrap for SocketTrap { + type ValueSlice = [Value; 1]; + type Fut<'a, C: AsContextMut> = + impl Future>; + + fn handle_trap>( + &self, + mut ctx: C, + _invocation: &wasmi::ResumableInvocation, + ) -> Self::Fut<'_, C> { + async move { + let mut ctx_mut = ctx.as_context_mut(); + let memory_ref = ctx_mut.as_context_mut().data_mut().memory.unwrap(); + let (data, store) = memory_ref.data_and_store_mut(ctx_mut); + let socket = &mut store.socket; + match *self { + Self::Read { buf, len } => { + let Some(range) = data.get_mut(buf as usize..(buf + len) as usize) else { + return Err(Trap::from(UnrecoverableSocketTrap::OutOfBounds).into()); + }; + match socket.read(len as usize).await { + Ok(net_data) => { + range[..net_data.len()].copy_from_slice(&net_data); + Ok([Value::I32(net_data.len() as i32)]) + } + Err(net_err) => Ok([Value::I32(net_err as i32)]), + } + } + Self::Write { buf, len } => { + let Some(range) = data.get(buf as usize..(buf + len) as usize) else { + return Err(Trap::from(UnrecoverableSocketTrap::OutOfBounds).into()); + }; + let written = range.len(); + match socket.reliable_write(&mut range.to_vec()).await { + Some(net_err) => Ok([Value::I32(net_err as i32)]), + None => Ok([Value::I32(written as i32)]), + } + } + } + } + } +} + +impl OneShotState { + fn trap_read(buf: u32, len: u32) -> Result { + Err(SocketTrap::Read { buf, len }.into()) + } + + fn trap_write(buf: u32, len: u32) -> Result { + Err(SocketTrap::Write { buf, len }.into()) + } + + fn print(caller: Caller<'_, Self>, buf: u32, len: u32) { + match caller + .data() + .memory + .unwrap() + .data(caller.as_context()) + .get(buf as usize..(buf + len) as usize) + .and_then(|v| core::str::from_utf8(v).ok()) + { + Some(message) => crate::serial_println!("[WASM LOG]: {}", message), + None => crate::serial_println!("[WASM] Failed to decode log message"), + } + } +} + +const FUEL_LIMIT: u64 = 1000; + +lazy_static! { + static ref ENGINE: wasmi::Engine = { + let mut config = wasmi::Config::default(); + config.consume_fuel(true); + let engine = wasmi::Engine::new(&config); + + engine + }; + static ref LINKER: wasmi::Linker = { + let mut linker = wasmi::Linker::new(&ENGINE); + + linker + .func_wrap("env", "read", OneShotState::trap_read) + .unwrap(); + linker + .func_wrap("env", "write", OneShotState::trap_write) + .unwrap(); + linker + .func_wrap("env", "print", OneShotState::print) + .unwrap(); + linker + }; +} + +impl WasmOneshot { + async fn spawn(socket: Socket, module: &Module) -> Result<(), wasmi::Error> { + let mut store = Store::new( + &ENGINE, + OneShotState { + socket, + memory: None, + }, + ); + store.add_fuel(FUEL_LIMIT)?; + let instance = LINKER.instantiate(&mut store, module)?.start(&mut store)?; + store.data_mut().memory = Some( + instance + .get_memory(&mut store, "memory") + .ok_or_else(|| Trap::from(UnrecoverableSocketTrap::NoExportedMemory))?, + ); + let oneshot_main = instance.get_typed_func::<(), ()>(&mut store, "main")?; + + oneshot_main + .call_resumable_async::(&mut store, ()) + .await + } +} + + +/// WASM One Shot format: +/// - 32-bit little-endian header indicating module length +/// - WASM binary module +#[allow(clippy::question_mark)] +async fn run_wasm_file(mut stream: Socket) -> Result<(), NetworkErrors> { + let data = stream.read(4).await; + if let Err(err) = data { + return Err(err); + } + let len_bytes = data.unwrap(); + assert!(len_bytes.len() == 4); + let mut bytes: [u8; 4] = [0; 4]; + for (i, byte) in len_bytes.iter().enumerate() { + bytes[i] = *byte; + } + let file_size = u32::from_le_bytes(bytes); + serial_println!("File is {} bytes", file_size); + let data = stream.read(file_size as usize).await; + if let Err(err) = data { + if err != NetworkErrors::ClosedSocket { + return Err(err); + } + return Err(NetworkErrors::FeatureNotAvailableYet); + } + let wasm = data.unwrap(); + let wasm_module = Module::new(&ENGINE, wasm.as_slice()); + if let Ok(module) = wasm_module { + if let Err(err) = WasmOneshot::spawn(stream, &module).await { + serial_println!("[SPAWN] An error occurred {:?}", err); + } + } else if let Err(err) = wasm_module { + serial_println!("[INIT] An error occurred {:?}", err); + } + Ok(()) +} + + +pub async fn wasm_oneshot_server() { + let socket_or_err = Socket::open(SocketType::TCP, 7774, 0).await; + if let Err(err) = socket_or_err { + println!("[ERR] {:?}", err); + return; + } + let mut socket_gen = socket_or_err.unwrap(); + // Allow 10 sockets to form (we don't have threads so only listening to one at a time) + for _ in 0..10 { + // Listen for a new session socket to form + let socket = socket_gen.listen().await.unwrap().unwrap(); + println!("[WASM] Stream started."); + // Build a packet stream and process it + if let Err(err) = run_wasm_file(socket).await { + serial_println!("[NET] An error occurred {:?}", err); + } + println!("[WASM] Stream finished."); + } + println!("[INFO] Finished up all allocated sockets for echoing"); +} diff --git a/wasm-demos/build.rs b/wasm-demos/build.rs index 33ba3a8..98b549c 100644 --- a/wasm-demos/build.rs +++ b/wasm-demos/build.rs @@ -19,6 +19,8 @@ fn main() -> Result<(), Box> { let out_dir = env::var_os("OUT_DIR").unwrap(); let rustc = env::var_os("RUSTC").unwrap(); let includes = Path::new(&out_dir).join("includes.rs"); + eprintln!("Extract WASM from this directory: {:?}", out_dir); + let mut includes_f = File::create(includes)?; for template in fs::read_dir("src")? { diff --git a/wasm-demos/src/add.rs b/wasm-demos/src/add.rs index 73d1b20..4e600c7 100644 --- a/wasm-demos/src/add.rs +++ b/wasm-demos/src/add.rs @@ -1,4 +1,4 @@ #[no_mangle] -pub extern "C" fn add(left: u32, right: u32) -> u32 { +pub extern "C" fn main(left: u32, right: u32) -> u32 { left + right } diff --git a/wasm-demos/src/complex.rs b/wasm-demos/src/complex.rs new file mode 100644 index 0000000..df7eef3 --- /dev/null +++ b/wasm-demos/src/complex.rs @@ -0,0 +1,24 @@ +#[no_mangle] +pub extern "C" fn main(first: u32, second: u32) -> u32 { + // This is a more complicated compute task, for testing the soupOS instance + let mut k: u32 = 0; + // Generate a weird number k + for _ in 0..1000 { + k += first; + k *= second; + k %= 1000000007 + } + // Make k odd + if k % 2 == 0 { + k -= 1; + } + // Make k "13x + 7" + while k % 13 != 7 { + k -= second; + if second % 13 == 0 { + k -= 1; + } + } + // Solve for x + ((k - 7) / 13) as u32 +} diff --git a/wasm-demos/src/sub.rs b/wasm-demos/src/sub.rs index 9836a48..6808d1e 100644 --- a/wasm-demos/src/sub.rs +++ b/wasm-demos/src/sub.rs @@ -1,4 +1,4 @@ #[no_mangle] -pub extern "C" fn sub(left: u32, right: u32) -> u32 { +pub extern "C" fn main(left: u32, right: u32) -> u32 { left - right } From b25d97a2ee9d77fbdcf9bf81b866142d161f6323 Mon Sep 17 00:00:00 2001 From: EthanLavi <89608076+EthanLavi@users.noreply.github.com> Date: Sun, 10 Dec 2023 19:03:48 -0500 Subject: [PATCH 35/36] Gitignore stuff --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index ea8c4bf..4f2331b 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,4 @@ /target +dump.pcap +Cargo.lock +.vscode/settings.json \ No newline at end of file From a4d0d449237bf9ff53574c7be3e5161655d597b3 Mon Sep 17 00:00:00 2001 From: EthanLavi <89608076+EthanLavi@users.noreply.github.com> Date: Sun, 10 Dec 2023 19:11:17 -0500 Subject: [PATCH 36/36] Adding prints --- Cargo.lock | 385 +++++++++++++++++++++++--------- kernel/src/task/wasm_oneshot.rs | 2 + 2 files changed, 281 insertions(+), 106 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d1edc6d..78814c6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -26,15 +26,23 @@ version = "1.0.75" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a4668cab20f66d8d020e1fbc0ebe47217433c1b6c8f2040faf858554e394ace6" +[[package]] +name = "arrayvec" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" + [[package]] name = "async-channel" -version = "1.9.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81953c529336010edd6d8e358f886d9581267795c61b19475b71314bffa46d35" +checksum = "1ca33f4bc4ed1babef42cad36cc1f51fa88be00420404e5b1e80ab1b18f7678c" dependencies = [ "concurrent-queue", - "event-listener 2.5.3", + "event-listener 4.0.0", + "event-listener-strategy", "futures-core", + "pin-project-lite", ] [[package]] @@ -43,11 +51,11 @@ version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fc5b45d93ef0529756f812ca52e44c221b35341892d3dcc34132ac02f3dd2af" dependencies = [ - "async-lock", + "async-lock 2.8.0", "autocfg", "cfg-if", "concurrent-queue", - "futures-lite", + "futures-lite 1.13.0", "log", "parking", "polling 2.8.0", @@ -59,22 +67,21 @@ dependencies = [ [[package]] name = "async-io" -version = "2.1.0" +version = "2.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10da8f3146014722c89e7859e1d7bb97873125d7346d10ca642ffab794355828" +checksum = "d6d3b15875ba253d1110c740755e246537483f152fa334f91abd7fe84c88b3ff" dependencies = [ - "async-lock", + "async-lock 3.2.0", "cfg-if", "concurrent-queue", "futures-io", - "futures-lite", + "futures-lite 2.1.0", "parking", - "polling 3.3.0", - "rustix 0.38.21", + "polling 3.3.1", + "rustix 0.38.28", "slab", "tracing", - "waker-fn", - "windows-sys", + "windows-sys 0.52.0", ] [[package]] @@ -86,6 +93,17 @@ dependencies = [ "event-listener 2.5.3", ] +[[package]] +name = "async-lock" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7125e42787d53db9dd54261812ef17e937c95a51e4d291373b670342fa44310c" +dependencies = [ + "event-listener 4.0.0", + "event-listener-strategy", + "pin-project-lite", +] + [[package]] name = "async-process" version = "1.8.1" @@ -93,14 +111,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ea6438ba0a08d81529c69b36700fa2f95837bfe3e776ab39cde9c14d9149da88" dependencies = [ "async-io 1.13.0", - "async-lock", + "async-lock 2.8.0", "async-signal", "blocking", "cfg-if", - "event-listener 3.0.1", - "futures-lite", - "rustix 0.38.21", - "windows-sys", + "event-listener 3.1.0", + "futures-lite 1.13.0", + "rustix 0.38.28", + "windows-sys 0.48.0", ] [[package]] @@ -109,16 +127,16 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e47d90f65a225c4527103a8d747001fc56e375203592b25ad103e1ca13124c5" dependencies = [ - "async-io 2.1.0", - "async-lock", + "async-io 2.2.1", + "async-lock 2.8.0", "atomic-waker", "cfg-if", "futures-core", "futures-io", - "rustix 0.38.21", + "rustix 0.38.28", "signal-hook-registry", "slab", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -178,18 +196,27 @@ dependencies = [ "wyz", ] +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + [[package]] name = "blocking" -version = "1.4.1" +version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c36a4d0d48574b3dd360b4b7d95cc651d2b6557b6402848a27d4b228a473e2a" +checksum = "6a37913e8dc4ddcc604f0c6d3bf2887c995153af3611de9e23c352b44c1b9118" dependencies = [ "async-channel", - "async-lock", + "async-lock 3.2.0", "async-task", "fastrand 2.0.1", "futures-io", - "futures-lite", + "futures-lite 2.1.0", "piper", "tracing", ] @@ -242,9 +269,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "concurrent-queue" -version = "2.3.0" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f057a694a54f12365049b0958a1685bb52d567f5593b355fbf685838e873d400" +checksum = "d16048cd947b08fa32c24458a22f5dc5e835264f689f4f5653210c69fd107363" dependencies = [ "crossbeam-utils", ] @@ -275,9 +302,9 @@ dependencies = [ [[package]] name = "crc-catalog" -version = "2.2.0" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cace84e55f07e7301bae1c519df89cdad8cc3cd868413d3fdbdeca9ff3db484" +checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5" [[package]] name = "crossbeam-queue" @@ -298,6 +325,26 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crypto-common" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +dependencies = [ + "generic-array", + "typenum", +] + +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", +] + [[package]] name = "downcast-rs" version = "1.2.0" @@ -306,12 +353,12 @@ checksum = "9ea835d29036a4087793836fa931b08837ad5e957da9e23886b29586fb9b6650" [[package]] name = "errno" -version = "0.3.5" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac3e13f66a2f95e32a39eaa81f6b95d42878ca0e1db0c7543723dfe12557e860" +checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245" dependencies = [ "libc", - "windows-sys", + "windows-sys 0.52.0", ] [[package]] @@ -322,15 +369,36 @@ checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0" [[package]] name = "event-listener" -version = "3.0.1" +version = "3.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01cec0252c2afff729ee6f00e903d479fba81784c8e2bd77447673471fdfaea1" +checksum = "d93877bcde0eb80ca09131a08d23f0a5c18a620b01db137dba666d18cd9b30c2" dependencies = [ "concurrent-queue", "parking", "pin-project-lite", ] +[[package]] +name = "event-listener" +version = "4.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "770d968249b5d99410d61f5bf89057f3199a077a04d087092f58e7d10692baae" +dependencies = [ + "concurrent-queue", + "parking", + "pin-project-lite", +] + +[[package]] +name = "event-listener-strategy" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "958e4d70b6d5e81971bebec42271ec641e7ff4e170a6fa605f2b8a8b65cb97d3" +dependencies = [ + "event-listener 4.0.0", + "pin-project-lite", +] + [[package]] name = "fastrand" version = "1.9.0" @@ -365,9 +433,9 @@ checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" [[package]] name = "futures" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23342abe12aba583913b2e62f22225ff9c950774065e4bfb61a19cd9770fec40" +checksum = "da0290714b38af9b4a7b094b8a37086d1b4e61f2df9122c3cad2577669145335" dependencies = [ "futures-channel", "futures-core", @@ -380,9 +448,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "955518d47e09b25bbebc7a18df10b81f0c766eaf4c4f1cccef2fca5f2a4fb5f2" +checksum = "ff4dd66668b557604244583e3e1e1eada8c5c2e96a6d0d6653ede395b78bbacb" dependencies = [ "futures-core", "futures-sink", @@ -403,15 +471,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c" +checksum = "eb1d22c66e66d9d72e1758f0bd7d4fd0bee04cad842ee34587d68c07e45d088c" [[package]] name = "futures-executor" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ccecee823288125bd88b4d7f565c9e58e41858e47ab72e8ea2d64e93624386e0" +checksum = "0f4fb8693db0cf099eadcca0efe2a5a22e4550f98ed16aba6c48700da29597bc" dependencies = [ "futures-core", "futures-task", @@ -439,11 +507,21 @@ dependencies = [ "waker-fn", ] +[[package]] +name = "futures-lite" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aeee267a1883f7ebef3700f262d2d54de95dfaf38189015a74fdc4e0c7ad8143" +dependencies = [ + "futures-core", + "pin-project-lite", +] + [[package]] name = "futures-macro" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" +checksum = "53b153fd91e4b0147f4aced87be237c98248656bb01050b96bf3ee89220a8ddb" dependencies = [ "proc-macro2", "quote", @@ -458,15 +536,15 @@ checksum = "e36d3378ee38c2a36ad710c5d30c2911d752cb941c00c72dbabfb786a7970817" [[package]] name = "futures-task" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65" +checksum = "efd193069b0ddadc69c46389b740bbccdd97203899b48d09c5f7969591d6bae2" [[package]] name = "futures-util" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533" +checksum = "a19526d624e703a3179b3d322efec918b6246ea0fa51d41124525f00f1cc8104" dependencies = [ "futures-channel", "futures-core", @@ -480,11 +558,21 @@ dependencies = [ "slab", ] +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "getrandom" -version = "0.2.10" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427" +checksum = "fe9006bed769170c11f845cf00c7c1e9092aeb3f268e007c3e760ac68008070f" dependencies = [ "cfg-if", "libc", @@ -505,9 +593,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.14.2" +version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f93e7192158dbcda357bdec5fb5788eebf8bbac027f3f33e719d29135ae84156" +checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" dependencies = [ "ahash", "allocator-api2", @@ -542,19 +630,21 @@ checksum = "eae7b9aee968036d54dce06cebaefd919e4472e753296daccd6d344e3e2df0c2" dependencies = [ "hermit-abi", "libc", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] name = "itoa" -version = "1.0.9" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" +checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" [[package]] name = "kernel" version = "0.1.0" dependencies = [ + "anyhow", + "arrayvec", "bootloader_api", "conquer-once", "crossbeam-queue", @@ -562,6 +652,7 @@ dependencies = [ "hashbrown", "lazy_static", "linked_list_allocator", + "md-5", "noto-sans-mono-bitmap", "pc-keyboard", "pci", @@ -584,9 +675,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.149" +version = "0.2.151" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a08173bc88b7955d1b3145aa561539096c421ac8debde8cbc3612ec635fee29b" +checksum = "302d7ab3130588088d277783b1e2d2e10c9e9e4a16dd9050e6ec93fb3e7048f4" [[package]] name = "libm" @@ -611,9 +702,9 @@ checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519" [[package]] name = "linux-raw-sys" -version = "0.4.10" +version = "0.4.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da2479e8c062e40bf0066ffa0bc823de0a9368974af99c9f6df941d2c231e03f" +checksum = "c4cd1a83af159aa67994778be9070f0ae1bd732942279cabb14f86f986a21456" [[package]] name = "llvm-tools" @@ -623,9 +714,9 @@ checksum = "955be5d0ca0465caf127165acb47964f911e2bc26073e865deb8be7189302faf" [[package]] name = "lock_api" -version = "0.4.10" +version = "0.4.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1cc9717a20b1bb222f333e6a92fd32f7d8a18ddc5a3191a11af45dcbf4dcd16" +checksum = "3c168f8615b12bc01f9c17e2eb0cc07dcae1940121185446edc3744920e8ef45" dependencies = [ "autocfg", "scopeguard", @@ -650,6 +741,16 @@ dependencies = [ "thiserror", ] +[[package]] +name = "md-5" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d89e7ee0cfbedfc4da3340218492196241d89eefb6dab27de5df917a6d2e78cf" +dependencies = [ + "cfg-if", + "digest", +] + [[package]] name = "memchr" version = "2.6.4" @@ -673,9 +774,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.18.0" +version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" [[package]] name = "os" @@ -783,28 +884,28 @@ dependencies = [ "libc", "log", "pin-project-lite", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] name = "polling" -version = "3.3.0" +version = "3.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e53b6af1f60f36f8c2ac2aad5459d75a5a9b4be1e8cdd40264f315d78193e531" +checksum = "cf63fa624ab313c11656b4cda960bfc46c410187ad493c41f6ba2d8c1e991c9e" dependencies = [ "cfg-if", "concurrent-queue", "pin-project-lite", - "rustix 0.38.21", + "rustix 0.38.28", "tracing", - "windows-sys", + "windows-sys 0.52.0", ] [[package]] name = "proc-macro2" -version = "1.0.69" +version = "1.0.70" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "134c189feb4956b20f6f547d2cf727d4c0fe06722b20a0eec87ed445a97f92da" +checksum = "39278fbbf5fb4f646ce651690877f89d1c5811a3d4acb27700c1cb3cdb78fd3b" dependencies = [ "unicode-ident", ] @@ -853,20 +954,20 @@ dependencies = [ "io-lifetimes", "libc", "linux-raw-sys 0.3.8", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] name = "rustix" -version = "0.38.21" +version = "0.38.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b426b0506e5d50a7d8dafcf2e81471400deb602392c7dd110815afb4eaf02a3" +checksum = "72e572a5e8ca657d7366229cdde4bd14c4eb5499a9573d4d366fe1b599daa316" dependencies = [ "bitflags 2.4.1", "errno", "libc", - "linux-raw-sys 0.4.10", - "windows-sys", + "linux-raw-sys 0.4.12", + "windows-sys 0.52.0", ] [[package]] @@ -877,9 +978,9 @@ checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4" [[package]] name = "ryu" -version = "1.0.15" +version = "1.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741" +checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c" [[package]] name = "scopeguard" @@ -889,9 +990,9 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "serde" -version = "1.0.190" +version = "1.0.193" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91d3c334ca1ee894a2c6f6ad698fe8c435b76d504b13d436f0685d648d6d96f7" +checksum = "25dd9975e68d0cb5aa1120c288333fc98731bd1dd12f561e468ea4728c042b89" dependencies = [ "serde_derive", ] @@ -907,9 +1008,9 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.190" +version = "1.0.193" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67c5609f394e5c2bd7fc51efda478004ea80ef42fee983d5c67a65e34f32c0e3" +checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3" dependencies = [ "proc-macro2", "quote", @@ -918,9 +1019,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.107" +version = "1.0.108" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b420ce6e3d8bd882e9b243c6eed35dbc9a6110c9769e74b584e0d68d1f20c65" +checksum = "3d1c7e3eac408d115102c4c24ad393e0821bb3a5df4d506a80f85f7a742a526b" dependencies = [ "itoa", "ryu", @@ -947,9 +1048,9 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.11.1" +version = "1.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "942b4a808e05215192e39f4ab80813e599068285906cc91aa64f923db842bd5a" +checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970" [[package]] name = "socket2" @@ -987,9 +1088,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.38" +version = "2.0.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e96b79aaa137db8f61e26363a0c9b47d8b4ec75da28b7d1d614c2303e232408b" +checksum = "23e78b90f2fcf45d3e842032ce32e3f2d1545ba6636271dcbf24fa306d87be7a" dependencies = [ "proc-macro2", "quote", @@ -1011,8 +1112,8 @@ dependencies = [ "cfg-if", "fastrand 2.0.1", "redox_syscall", - "rustix 0.38.21", - "windows-sys", + "rustix 0.38.28", + "windows-sys 0.48.0", ] [[package]] @@ -1051,6 +1152,12 @@ version = "0.1.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" +[[package]] +name = "typenum" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" + [[package]] name = "uart_16550" version = "0.3.0" @@ -1070,9 +1177,9 @@ checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" [[package]] name = "uuid" -version = "1.5.0" +version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88ad59a7560b41a70d191093a945f0b87bc1deeda46fb237479708a1d6b6cdfc" +checksum = "5e395fcf16a7a3d8127ec99782007af141946b4795001f876d54fb0d55978560" dependencies = [ "getrandom", ] @@ -1113,9 +1220,9 @@ version = "0.1.0" [[package]] name = "wasmi" -version = "0.31.0" +version = "0.31.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f341edb80021141d4ae6468cbeefc50798716a347d4085c3811900049ea8945" +checksum = "acfc1e384a36ca532d070a315925887247f3c7e23567e23e0ac9b1c5d6b8bf76" dependencies = [ "smallvec", "spin 0.9.8", @@ -1179,7 +1286,16 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" dependencies = [ - "windows-targets", + "windows-targets 0.48.5", +] + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.0", ] [[package]] @@ -1188,13 +1304,28 @@ version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "windows_aarch64_gnullvm 0.48.5", + "windows_aarch64_msvc 0.48.5", + "windows_i686_gnu 0.48.5", + "windows_i686_msvc 0.48.5", + "windows_x86_64_gnu 0.48.5", + "windows_x86_64_gnullvm 0.48.5", + "windows_x86_64_msvc 0.48.5", +] + +[[package]] +name = "windows-targets" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a18201040b24831fbb9e4eb208f8892e1f50a37feb53cc7ff887feb8f50e7cd" +dependencies = [ + "windows_aarch64_gnullvm 0.52.0", + "windows_aarch64_msvc 0.52.0", + "windows_i686_gnu 0.52.0", + "windows_i686_msvc 0.52.0", + "windows_x86_64_gnu 0.52.0", + "windows_x86_64_gnullvm 0.52.0", + "windows_x86_64_msvc 0.52.0", ] [[package]] @@ -1203,42 +1334,84 @@ version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea" + [[package]] name = "windows_aarch64_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef" + [[package]] name = "windows_i686_gnu" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" +[[package]] +name = "windows_i686_gnu" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313" + [[package]] name = "windows_i686_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" +[[package]] +name = "windows_i686_msvc" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a" + [[package]] name = "windows_x86_64_gnu" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd" + [[package]] name = "windows_x86_64_gnullvm" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e" + [[package]] name = "windows_x86_64_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" + [[package]] name = "wyz" version = "0.5.1" @@ -1273,18 +1446,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.7.20" +version = "0.7.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd66a62464e3ffd4e37bd09950c2b9dd6c4f8767380fabba0d523f9a775bc85a" +checksum = "306dca4455518f1f31635ec308b6b3e4eb1b11758cefafc782827d0aa7acb5c7" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.7.20" +version = "0.7.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "255c4596d41e6916ced49cfafea18727b24d67878fa180ddfd69b9df34fd1726" +checksum = "be912bf68235a88fbefd1b73415cb218405958d1655b2ece9035a19920bdf6ba" dependencies = [ "proc-macro2", "quote", diff --git a/kernel/src/task/wasm_oneshot.rs b/kernel/src/task/wasm_oneshot.rs index bd4339f..884e0ad 100644 --- a/kernel/src/task/wasm_oneshot.rs +++ b/kernel/src/task/wasm_oneshot.rs @@ -72,6 +72,7 @@ impl AsyncTrap for SocketTrap { let Some(range) = data.get_mut(buf as usize..(buf + len) as usize) else { return Err(Trap::from(UnrecoverableSocketTrap::OutOfBounds).into()); }; + serial_println!("[READING] {}", len); match socket.read(len as usize).await { Ok(net_data) => { range[..net_data.len()].copy_from_slice(&net_data); @@ -85,6 +86,7 @@ impl AsyncTrap for SocketTrap { return Err(Trap::from(UnrecoverableSocketTrap::OutOfBounds).into()); }; let written = range.len(); + serial_println!("[WRITING] {}", buf); match socket.reliable_write(&mut range.to_vec()).await { Some(net_err) => Ok([Value::I32(net_err as i32)]), None => Ok([Value::I32(written as i32)]),