diff --git a/kernel/src/objects/endpoint.rs b/kernel/src/objects/endpoint.rs new file mode 100644 index 0000000..bb0c525 --- /dev/null +++ b/kernel/src/objects/endpoint.rs @@ -0,0 +1,211 @@ +use crate::{objects::*, plat::trap::TrapContextOps}; +use core::cmp::min; +use tcb::{ThreadState, EP_QUEUE_LINK_ID}; +use uapi::{cap::ObjectType, fault::Fault, syscall::*}; +use utils::{addr::*, array::*, linked_list::Link}; + +#[derive(Debug, Copy, Clone, PartialEq)] +pub enum EndpointState { + Idle, + Sending, + Receiving, +} + +#[repr(C)] +pub struct EndpointObject { + pub queue: Link, + pub state: EndpointState, +} + +impl KernelObject for EndpointObject { + const OBJ_TYPE: ObjectType = ObjectType::Endpoint; +} + +/* + * in vanilla seL4, EndpointCap layout: + * > args[0]: | cap_tag | can_grant_reply | can_grant | can_recv | can_send | base_ptr | + * > | [63:59] | [58:58] | [57:57] | [56:56] | [55:55] | [54:0] | + * > args[1]: | badge | + * > | [63::0] | + * + * in our implementation, EndpointCap layout: + * > args[0]: badge + * > args[1]: none + * > ptr: base_ptr + * > cap_type: cap_tag + */ + +pub type EndpointCap<'a> = Cap<'a, EndpointObject>; + +impl<'a> EndpointCap<'a> { + pub fn mint(ptr: PhysAddr, badge: usize) -> RawCap { + RawCap::new(badge, 0, ptr, ObjectType::Endpoint) + } + + pub fn can_send(&self) -> bool { + self.cte.cap.args[0] & 0x2 != 0 + } + + pub fn can_recv(&self) -> bool { + self.cte.cap.args[0] & 0x1 != 0 + } + + pub fn badge(&self) -> usize { + self.cte.cap.args[1] + } + + pub fn state(&self) -> EndpointState { + self.as_object().state + } + + pub fn set_state(&mut self, state: EndpointState) { + self.as_object_mut().state = state; + } + + pub fn do_send(&mut self, send: &mut TcbObject) -> SysResult<()> { + match self.state() { + EndpointState::Idle | EndpointState::Sending => { + send.set_state(ThreadState::Sending); + send.set_badge(self.badge()); + + self.set_state(EndpointState::Sending); + self.as_object_mut().queue.append(send); + }, + EndpointState::Receiving => { + assert!(!self.as_object().queue.is_empty(), "Receive endpoint queue must not be empty"); + + let recv = { + let recv = self.as_object_mut().queue.next_mut().unwrap(); + recv.ep_queue.detach(); + + // SAFETY: `recv` is detached from the queue, no longer related to `self` + unsafe { &mut *(recv as *mut TcbObject) } + }; + + if self.as_object().queue.is_empty() { + self.set_state(EndpointState::Idle); + } + + do_ipc_transfer(send, send.get_message_info()?, recv, recv.get_message_info()?, self.badge())?; + + recv.set_state(ThreadState::Running); + recv.schedule_next(); + }, + }; + + Ok(()) + } + + pub fn do_recv(&mut self, recv: &mut TcbObject) -> SysResult<()> { + match self.state() { + EndpointState::Idle | EndpointState::Receiving => { + recv.set_state(ThreadState::Receiving); + + self.set_state(EndpointState::Receiving); + self.as_object_mut().queue.append(recv); + }, + EndpointState::Sending => { + assert!(!self.as_object().queue.is_empty(), "Send endpoint queue must not be empty"); + + let send = { + let send = self.as_object_mut().queue.next_mut().unwrap(); + send.ep_queue.detach(); + + // SAFETY: `send` is detached from the queue, no longer related to `self` + unsafe { &mut *(send as *mut TcbObject) } + }; + + if self.as_object().queue.is_empty() { + self.set_state(EndpointState::Idle); + } + + do_ipc_transfer(send, send.get_message_info()?, recv, recv.get_message_info()?, self.badge())?; + + send.set_state(ThreadState::Running); + send.schedule_next(); + }, + }; + + Ok(()) + } +} + +fn copy_message(send: &mut TcbObject, send_msg: MessageInfo, recv: &mut TcbObject, recv_msg: MessageInfo) -> SysResult<()> { + let send_len = send_msg.length() - if send_msg.transfer_cap() { 2 } else { 0 }; + let recv_len = recv_msg.length() - if recv_msg.transfer_cap() { 2 } else { 0 }; + let msg_len = min(send_len, recv_len); + + // get ipc buffer + let mut recv_cap = recv.buffer()?; + let recv_buf: &mut [usize] = cast_arr_mut(recv_cap.as_object_mut()); + + let send_cap = send.buffer()?; + let send_buf: &[usize] = cast_arr(send_cap.as_object()); + + // copy ipc buffer + let bgn = REG_ARG_MAX + 1; + recv_buf[..(msg_len - bgn)].copy_from_slice(&send_buf[..(msg_len - bgn)]); + + // copy register + for i in REG_ARG_0..min(REG_ARG_MAX + 1, msg_len) { + recv.trapframe.set_reg(i, send.trapframe.get_reg(i)); + } + + Ok(()) +} + +fn ipc_get_args(tcb: &TcbObject, idx: usize) -> SysResult { + let ret = if idx <= REG_ARG_MAX { + tcb.trapframe.get_reg(idx) + } else { + let buf = tcb.buffer()?; + let buf: &[usize] = cast_arr(buf.as_object()); + buf[idx] + }; + + Ok(ret) +} + +fn transfer_cap(send: &mut TcbObject, send_msg: MessageInfo, recv: &mut TcbObject, recv_msg: MessageInfo) -> SysResult<()> { + if !send_msg.transfer_cap() || !recv_msg.transfer_cap() { + return Ok(()); + } + + let send_cspace = send.cspace()?; + let recv_cspace = recv.cspace()?; + + let send_cptr = ipc_get_args(send, send_msg.length() - LEN_ARG_OFFSET_CPTR)?; + let send_bits = ipc_get_args(send, send_msg.length() - LEN_ARG_OFFSET_BITS)?; + + let recv_cptr = ipc_get_args(recv, recv_msg.length() - LEN_ARG_OFFSET_CPTR)?; + let recv_bits = ipc_get_args(recv, recv_msg.length() - LEN_ARG_OFFSET_BITS)?; + + let send_cap = send_cspace.resolve_address_bits(send_cptr, send_bits)?; + let recv_cap = recv_cspace.resolve_address_bits(recv_cptr, recv_bits)?; + + let dest = NullCap::try_from(recv_cap)?; + dest.override_cap(send_cap.cap); + + unsafe { send_cap.cap.update(|cap| *cap = NullCap::mint()) }; + + Ok(()) +} + +fn do_ipc_transfer( + send: &mut TcbObject, + send_msg: MessageInfo, + recv: &mut TcbObject, + recv_msg: MessageInfo, + _badge: usize, +) -> SysResult<()> { + if send.fault() == Fault::Null { + // do normal transfer + copy_message(send, send_msg, recv, recv_msg)?; + transfer_cap(send, send_msg, recv, recv_msg)?; + + todo!("do_ipc_transfer: normal transfer: uapi/reply"); + } else { + // do fault transfer + todo!("do_ipc_transfer: fault transfer"); + } +} diff --git a/kernel/src/objects/mod.rs b/kernel/src/objects/mod.rs index 5f2718c..9fea7bf 100644 --- a/kernel/src/objects/mod.rs +++ b/kernel/src/objects/mod.rs @@ -24,6 +24,7 @@ use utils::addr::AddressOps; pub mod cap; pub mod cnode; +pub mod endpoint; pub mod frame; pub mod null; pub mod table; @@ -32,6 +33,7 @@ pub mod untyped; pub use cap::{CapEntry, RawCap}; pub use cnode::{CNodeCap, CNodeObject}; +pub use endpoint::{EndpointCap, EndpointObject}; pub use frame::{FrameCap, FrameObject}; pub use null::NullCap; pub use table::{TableCap, TableObject};