use super::{alloc_kernel_object_id, KernelObject, KernelObjectId, KernelObjectType};
use alloc::{
collections::VecDeque,
fmt,
sync::{Arc, Weak},
vec::Vec,
};
use poplar::syscall::{GetMessageError, SendMessageError, CHANNEL_MAX_NUM_HANDLES};
use spinning_top::Spinlock;
use tracing::warn;
/// One end of a kernel message channel.
///
/// Messages sent from the peer end are appended to this end's `messages` queue. An end created
/// with [`ChannelEnd::new_kernel_channel`] has no peer, and messages sent from it are discarded.
#[derive(Debug)]
pub struct ChannelEnd {
    /// Kernel object ID of this channel end.
    pub id: KernelObjectId,
    /// ID supplied as `owner` when the end was created (both ends of a pair share it).
    /// NOTE(review): presumably the owning task/address space — confirm against callers.
    pub owner: KernelObjectId,
    /// Messages waiting to be received on this end, oldest at the front.
    pub messages: Spinlock<VecDeque<Message>>,
    /// The peer end, held weakly so the two ends do not form an `Arc` reference cycle.
    /// `None` for a kernel channel. If the peer has been dropped, `send` reports
    /// `SendMessageError::OtherEndDisconnected`.
    other_end: Option<Weak<ChannelEnd>>,
}
impl ChannelEnd {
    /// Creates a connected pair of channel ends.
    ///
    /// Each end holds a weak reference to the other, so a message sent on one end is queued on
    /// the other. Both ends record the same `owner`.
    pub fn new_channel(owner: KernelObjectId) -> (Arc<ChannelEnd>, Arc<ChannelEnd>) {
        // Allocate both IDs up front, in the same order as the ends are returned.
        let id_a = alloc_kernel_object_id();
        let id_b = alloc_kernel_object_id();

        // `Arc::new_cyclic` provides a `Weak` to `end_a` before `end_a` is constructed. We use it
        // to build `end_b` pointing back at `end_a`, then finish `end_a` pointing at `end_b`.
        // This avoids having to patch `end_a` through `unsafe` after it has already been
        // shared via a `Weak`.
        let mut end_b = None;
        let end_a = Arc::new_cyclic(|weak_a: &Weak<ChannelEnd>| {
            let b = Arc::new(ChannelEnd {
                id: id_b,
                owner,
                messages: Spinlock::new(VecDeque::new()),
                other_end: Some(weak_a.clone()),
            });
            let other_end = Some(Arc::downgrade(&b));
            end_b = Some(b);

            ChannelEnd { id: id_a, owner, messages: Spinlock::new(VecDeque::new()), other_end }
        });
        let end_b = end_b.expect("`Arc::new_cyclic` always runs its initializer");

        (end_a, end_b)
    }

    /// Creates a single channel end with no peer.
    ///
    /// Messages can still be queued on it with [`ChannelEnd::add_message`], but anything
    /// [`ChannelEnd::send`] on it is discarded (with a warning).
    pub fn new_kernel_channel(owner: KernelObjectId) -> Arc<ChannelEnd> {
        Arc::new(ChannelEnd {
            id: alloc_kernel_object_id(),
            owner,
            messages: Spinlock::new(VecDeque::new()),
            other_end: None,
        })
    }

    /// Appends `message` to the back of this end's queue.
    pub fn add_message(&self, message: Message) {
        self.messages.lock().push_back(message);
    }

    /// Sends `message` to the peer end, queueing it there.
    ///
    /// On a kernel channel (no peer), the message is dropped, a warning is logged, and `Ok(())`
    /// is returned.
    ///
    /// # Errors
    /// Returns [`SendMessageError::OtherEndDisconnected`] if the peer end has been dropped.
    pub fn send(&self, message: Message) -> Result<(), SendMessageError> {
        if let Some(ref other_end) = self.other_end {
            match other_end.upgrade() {
                Some(other_end) => {
                    other_end.add_message(message);
                    Ok(())
                }
                None => Err(SendMessageError::OtherEndDisconnected),
            }
        } else {
            warn!("Discarding message sent down kernel channel");
            Ok(())
        }
    }

    /// Takes the oldest queued message and hands it to `f`.
    ///
    /// If `f` returns `Ok(value)`, the message is consumed and `value` is returned. If `f`
    /// returns `Err((message, err))`, the message is put back at the front of the queue, so it
    /// stays next in line, and `err` is returned.
    ///
    /// `f` runs while this end's queue lock is held. It must not lock `messages` on this same
    /// end again (e.g. via `add_message`).
    ///
    /// # Errors
    /// Returns [`GetMessageError::NoMessage`] if the queue is empty, or whatever error `f`
    /// rejects the message with.
    pub fn receive<F, R>(&self, f: F) -> Result<R, GetMessageError>
    where
        F: FnOnce(Message) -> Result<R, (Message, GetMessageError)>,
    {
        let mut message_queue = self.messages.lock();
        match f(message_queue.pop_front().ok_or(GetMessageError::NoMessage)?) {
            Ok(value) => Ok(value),
            Err((message, err)) => {
                message_queue.push_front(message);
                Err(err)
            }
        }
    }
}
impl KernelObject for ChannelEnd {
fn id(&self) -> KernelObjectId {
self.id
}
fn typ(&self) -> KernelObjectType {
KernelObjectType::Channel
}
}
/// A message carried over a channel: a byte payload plus a fixed number of slots
/// (`CHANNEL_MAX_NUM_HANDLES`) for attached kernel objects.
pub struct Message {
    /// The message payload.
    pub bytes: Vec<u8>,
    /// Kernel objects attached to the message. Empty slots are `None`.
    pub handle_objects: [Option<Arc<dyn KernelObject>>; CHANNEL_MAX_NUM_HANDLES],
}
impl fmt::Debug for Message {
    /// Prints the payload bytes only. The attached handle objects are not printed, which
    /// `finish_non_exhaustive` marks with a trailing `..`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Message");
        builder.field("bytes", &self.bytes);
        builder.finish_non_exhaustive()
    }
}
impl Message {
pub fn num_handles(&self) -> usize {
self.handle_objects.iter().fold(0, |n, ref handle| if handle.is_some() { n + 1 } else { n })
}
}