diff --git a/espanso-ipc/src/lib.rs b/espanso-ipc/src/lib.rs index 1abbb4f..9ba72a9 100644 --- a/espanso-ipc/src/lib.rs +++ b/espanso-ipc/src/lib.rs @@ -18,9 +18,8 @@ */ use anyhow::Result; -use crossbeam::channel::{unbounded, Receiver}; use serde::{de::DeserializeOwned, Serialize}; -use std::path::Path; +use std::{path::Path}; use thiserror::Error; #[cfg(target_os = "windows")] @@ -29,13 +28,22 @@ pub mod windows; #[cfg(not(target_os = "windows"))] pub mod unix; +pub type EventHandler = Box EventHandlerResponse>; + +pub enum EventHandlerResponse { + NoResponse, + Response(Event), + Error(anyhow::Error), + Exit, +} + pub trait IPCServer { - fn run(&self) -> Result<()>; - fn accept_one(&self) -> Result<()>; + fn run(self, handler: EventHandler) -> Result<()>; } pub trait IPCClient { - fn send(&self, event: Event) -> Result<()>; + fn send_sync(&mut self, event: Event) -> Result; + fn send_async(&mut self, event: Event) -> Result<()>; } #[cfg(not(target_os = "windows"))] @@ -55,17 +63,16 @@ pub fn client(id: &str, parent_dir: &Path) -> Result( +pub fn server( id: &str, _: &Path, -) -> Result<(impl IPCServer, Receiver)> { - let (sender, receiver) = unbounded(); - let server = windows::WinIPCServer::new(id, sender)?; - Ok((server, receiver)) +) -> Result> { + let server = windows::WinIPCServer::new(id)?; + Ok(server) } #[cfg(target_os = "windows")] -pub fn client(id: &str, _: &Path) -> Result> { +pub fn client(id: &str, _: &Path) -> Result> { let client = windows::WinIPCClient::new(id)?; Ok(client) } @@ -73,45 +80,184 @@ pub fn client(id: &str, _: &Path) -> Result("testespansoipc", &std::env::temp_dir()).unwrap(); - let server_handle = std::thread::spawn(move || { - server.accept_one().unwrap(); + fn ipc_async_message() { + let server = server::("testespansoipcasync", &std::env::temp_dir()).unwrap(); + + let client_handle = std::thread::spawn(move || { + let mut client = client::("testespansoipcasync", &std::env::temp_dir()).unwrap(); + + client.send_async(Event::Async).unwrap(); + client.send_async(Event::ExitRequest).unwrap(); }); - // TODO: avoid delay and change the IPC code so that we can wait for the IPC - std::thread::sleep(std::time::Duration::from_millis(300)); + server + .run(Box::new(move |event| match event { + Event::ExitRequest => EventHandlerResponse::Exit, + evt => { + assert!(matches!(evt, Event::Async)); + EventHandlerResponse::NoResponse + } + })) + .unwrap(); - let client = client::("testespansoipc", &std::env::temp_dir()).unwrap(); - client.send(Event::Foo("hello".to_string())).unwrap(); - - let event = receiver.recv().unwrap(); - assert!(matches!(event, Event::Foo(x) if x == "hello")); - - server_handle.join().unwrap(); + client_handle.join().unwrap(); } #[test] - fn ipc_client_fails_to_send() { - let client = client::("testespansoipc", &std::env::temp_dir()).unwrap(); - assert!(client.send(Event::Foo("hello".to_string())).is_err()); + fn ipc_sync_message() { + let server = server::("testespansoipcsync", &std::env::temp_dir()).unwrap(); + + let client_handle = std::thread::spawn(move || { + let mut client = client::("testespansoipcsync", &std::env::temp_dir()).unwrap(); + + let response = client.send_sync(Event::Sync("test".to_owned())).unwrap(); + client.send_async(Event::ExitRequest).unwrap(); + + assert!(matches!(response, Event::SyncResult(s) if s == "test")); + }); + + server + .run(Box::new(move |event| match event { + Event::ExitRequest => EventHandlerResponse::Exit, + Event::Sync(s) => EventHandlerResponse::Response(Event::SyncResult(s)), + _ => EventHandlerResponse::NoResponse, + })) + .unwrap(); + + client_handle.join().unwrap(); + } + + #[test] + fn ipc_multiple_sync_with_delay_message() { + let server = server::("testespansoipcmultiplesync", &std::env::temp_dir()).unwrap(); + + let client_handle = std::thread::spawn(move || { + let mut client = client::("testespansoipcmultiplesync", &std::env::temp_dir()).unwrap(); + + let response = client.send_sync(Event::Sync("test".to_owned())).unwrap(); + + std::thread::sleep(std::time::Duration::from_millis(500)); + + let response2 = client.send_sync(Event::Sync("test2".to_owned())).unwrap(); + client.send_async(Event::ExitRequest).unwrap(); + + assert!(matches!(response, Event::SyncResult(s) if s == "test")); + assert!(matches!(response2, Event::SyncResult(s) if s == "test2")); + }); + + server + .run(Box::new(move |event| match event { + Event::ExitRequest => EventHandlerResponse::Exit, + Event::Sync(s) => EventHandlerResponse::Response(Event::SyncResult(s)), + _ => EventHandlerResponse::NoResponse, + })) + .unwrap(); + + client_handle.join().unwrap(); + } + + #[test] + fn ipc_multiple_clients() { + let server = server::("testespansoipcmultiple", &std::env::temp_dir()).unwrap(); + + let (tx, rx) = channel(); + + let client_handle = std::thread::spawn(move || { + let mut client = client::("testespansoipcmultiple", &std::env::temp_dir()).unwrap(); + + let response = client.send_sync(Event::Sync("client1".to_owned())).unwrap(); + + tx.send(()).unwrap(); + + assert!(matches!(response, Event::SyncResult(s) if s == "client1")); + }); + + let client_handle2 = std::thread::spawn(move || { + let mut client = client::("testespansoipcmultiple", &std::env::temp_dir()).unwrap(); + + let response = client.send_sync(Event::Sync("client2".to_owned())).unwrap(); + + // Wait for the other client before terminating + rx.recv().unwrap(); + + client.send_async(Event::ExitRequest).unwrap(); + + assert!(matches!(response, Event::SyncResult(s) if s == "client2")); + }); + + server + .run(Box::new(move |event| match event { + Event::ExitRequest => EventHandlerResponse::Exit, + Event::Sync(s) => EventHandlerResponse::Response(Event::SyncResult(s)), + _ => EventHandlerResponse::NoResponse, + })) + .unwrap(); + + client_handle.join().unwrap(); + client_handle2.join().unwrap(); + } + + #[test] + fn ipc_sync_big_payload_message() { + let server = server::("testespansoipcsyncbig", &std::env::temp_dir()).unwrap(); + + let client_handle = std::thread::spawn(move || { + let mut client = client::("testespansoipcsyncbig", &std::env::temp_dir()).unwrap(); + + let mut payload = String::new(); + for _ in 0..10000 { + payload.push_str("log string repeated"); + } + let response = client.send_sync(Event::Sync(payload.clone())).unwrap(); + client.send_async(Event::ExitRequest).unwrap(); + + assert!(matches!(response, Event::SyncResult(s) if s == payload)); + }); + + server + .run(Box::new(move |event| match event { + Event::ExitRequest => EventHandlerResponse::Exit, + Event::Sync(s) => EventHandlerResponse::Response(Event::SyncResult(s)), + _ => EventHandlerResponse::NoResponse, + })) + .unwrap(); + + client_handle.join().unwrap(); } } diff --git a/espanso-ipc/src/windows.rs b/espanso-ipc/src/windows.rs index 6670743..c27e04c 100644 --- a/espanso-ipc/src/windows.rs +++ b/espanso-ipc/src/windows.rs @@ -18,95 +18,154 @@ */ use anyhow::Result; -use crossbeam::channel::Sender; use log::{error, info}; -use named_pipe::{PipeClient, PipeOptions}; +use named_pipe::{ConnectingServer, PipeClient, PipeOptions}; use serde::{de::DeserializeOwned, Serialize}; -use std::io::{BufReader, Read, Write}; +use std::{io::{Write}}; -use crate::{IPCClient, IPCServer, IPCServerError}; +use crate::{ + EventHandler, EventHandlerResponse, IPCClient, IPCClientError, IPCServer, +}; -const CLIENT_TIMEOUT: u32 = 2000; +const DEFAULT_CLIENT_TIMEOUT: u32 = 2000; -pub struct WinIPCServer { - options: PipeOptions, - sender: Sender, +pub struct WinIPCServer { + server: Option, } -impl WinIPCServer { - pub fn new(id: &str, sender: Sender) -> Result { +impl WinIPCServer { + pub fn new(id: &str) -> Result { let pipe_name = format!("\\\\.\\pipe\\{}", id); let options = PipeOptions::new(&pipe_name); + let server = Some(options.single()?); info!("binded to named pipe: {}", pipe_name); - Ok(Self { options, sender }) + Ok(Self { server }) } } -impl IPCServer for WinIPCServer { - fn run(&self) -> anyhow::Result<()> { +impl IPCServer for WinIPCServer { + fn run(mut self, handler: EventHandler) -> anyhow::Result<()> { + let server = self + .server + .take() + .expect("unable to extract IPC server handle"); + let mut stream = server.wait()?; + loop { - self.accept_one()?; - } - } - - fn accept_one(&self) -> Result<()> { - let server = self.options.single()?; - let connection = server.wait(); - - match connection { - Ok(stream) => { - let mut json_str = String::new(); - let mut buf_reader = BufReader::new(stream); - let result = buf_reader.read_to_string(&mut json_str); - - match result { - Ok(_) => { - let event: Result = serde_json::from_str(&json_str); + // Read multiple commands from the client + loop { + match read_line(&mut stream) { + Ok(Some(line)) => { + let event: Result = serde_json::from_str(&line); match event { - Ok(event) => { - if self.sender.send(event).is_err() { - return Err(IPCServerError::SendFailed().into()); + Ok(event) => match handler(event) { + EventHandlerResponse::Response(response) => { + let mut json_event = serde_json::to_string(&response)?; + json_event.push('\n'); + stream.write_all(json_event.as_bytes())?; + stream.flush()?; } - } + EventHandlerResponse::NoResponse => { + // Async event, no need to reply + } + EventHandlerResponse::Error(err) => { + error!("ipc handler reported an error: {}", err); + } + EventHandlerResponse::Exit => { + return Ok(()); + } + }, Err(error) => { error!("received malformed event from ipc stream: {}", error); + break; } } } + Ok(None) => { + // EOF reached + break; + } Err(error) => { error!("error reading ipc stream: {}", error); + break; } } } - Err(err) => { - return Err(IPCServerError::StreamEnded(err).into()); - } - }; - Ok(()) + stream = stream.disconnect()?.wait()?; + } + } +} + +// Unbuffered version, necessary to concurrently write +// to the buffer if necessary (when receiving sync messages) +fn read_line(stream: R) -> Result> { + let mut buffer = Vec::new(); + + let mut is_eof = true; + + for byte_res in stream.bytes() { + let byte = byte_res?; + + if byte == 10 { + // Newline + break; + } else { + buffer.push(byte); + } + + is_eof = false; + } + + if is_eof { + Ok(None) + } else { + Ok(Some(String::from_utf8(buffer)?)) } } pub struct WinIPCClient { - pipe_name: String, + stream: PipeClient, } impl WinIPCClient { pub fn new(id: &str) -> Result { let pipe_name = format!("\\\\.\\pipe\\{}", id); - Ok(Self { pipe_name }) + + let stream = PipeClient::connect_ms(&pipe_name, DEFAULT_CLIENT_TIMEOUT)?; + Ok(Self { stream }) } } -impl IPCClient for WinIPCClient { - fn send(&self, event: Event) -> Result<()> { - let mut stream = PipeClient::connect_ms(&self.pipe_name, CLIENT_TIMEOUT)?; +impl IPCClient for WinIPCClient { + fn send_sync(&mut self, event: Event) -> Result { + { + let mut json_event = serde_json::to_string(&event)?; + json_event.push('\n'); + self.stream.write_all(json_event.as_bytes())?; + self.stream.flush()?; + } - let json_event = serde_json::to_string(&event)?; - stream.write_all(json_event.as_bytes())?; + // Read the response + if let Some(line) = read_line(&mut self.stream)? { + let event: Result = serde_json::from_str(&line); + match event { + Ok(response) => Ok(response), + Err(err) => Err(IPCClientError::MalformedResponse(err.into()).into()), + } + } else { + Err(IPCClientError::EmptyResponse.into()) + } + } + + fn send_async(&mut self, event: Event) -> Result<()> { + let mut json_event = serde_json::to_string(&event)?; + json_event.push('\n'); + self.stream.write_all(json_event.as_bytes())?; + self.stream.flush()?; Ok(()) }