From 221f75e9764780ae02cf7062d2424e47154a00f5 Mon Sep 17 00:00:00 2001 From: Andreas Tsouchlos Date: Wed, 31 Dec 2025 18:21:38 +0200 Subject: [PATCH] Modify procedural macro to allow more flexible function signatures; Refactor code --- macros/src/lib.rs | 70 ++++++++++++++++------------ src/app.rs | 68 ++++++++------------------- src/app/run.rs | 97 +++++++++++++++++++++++++++++++++++++++ src/app/seeding.rs | 23 ++-------- src/app/snowballing.rs | 37 ++------------- src/crossterm.rs | 101 ++--------------------------------------- 6 files changed, 173 insertions(+), 223 deletions(-) create mode 100644 src/app/run.rs diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 07e1287..fe8d1f1 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -2,7 +2,7 @@ use proc_macro::TokenStream; use quote::{format_ident, quote}; -use syn::{FnArg, ImplItem, ItemImpl, Pat, Type, parse_macro_input}; +use syn::{FnArg, ImplItem, ItemImpl, ReturnType, Type, parse_macro_input}; #[proc_macro_attribute] pub fn component(attr: TokenStream, item: TokenStream) -> TokenStream { @@ -23,14 +23,12 @@ pub fn component(attr: TokenStream, item: TokenStream) -> TokenStream { for item in &mut input.items { if let ImplItem::Fn(method) = item { - // Check for #[action] let has_action = method .attrs .iter() .any(|attr| attr.path().is_ident("action")); if has_action { - // Remove the #[action] attribute so it doesn't cause compilation errors method.attrs.retain(|attr| !attr.path().is_ident("action")); let method_name = &method.sig.ident; @@ -39,44 +37,61 @@ pub fn component(attr: TokenStream, item: TokenStream) -> TokenStream { snake_to_pascal(&method_name.to_string()) ); + let returns_result = match &method.sig.output { + ReturnType::Default => false, + ReturnType::Type(_, ty) => { + quote!(#ty).to_string().contains("Result") + } + }; + let mut variant_fields = Vec::new(); let mut call_args = Vec::new(); + let mut variant_arg_names = Vec::new(); + // Inspect arguments to decide what goes into the Enum and what is injected for arg in method.sig.inputs.iter().skip(1) { // Skip 'self' if let FnArg::Typed(pat_type) = arg { if is_sender_type(&pat_type.ty) { + // Inject the 'tx' from handle_action, don't add to Enum call_args.push(quote!(tx)); - continue; - } + } else { + // This is a data argument, add to Enum + let ty = &pat_type.ty; + variant_fields.push(quote!(#ty)); - let ty = &pat_type.ty; - variant_fields.push(quote!(#ty)); - - if let Pat::Ident(pi) = &*pat_type.pat { - let arg_ident = &pi.ident; - call_args.push(quote!(#arg_ident)); + let arg_id = format_ident!( + "arg_{}", + variant_arg_names.len() + ); + variant_arg_names.push(arg_id.clone()); + call_args.push(quote!(#arg_id)); } } } + // Generate Enum Variant if variant_fields.is_empty() { enum_variants.push(quote!(#variant_name)); - match_arms.push(quote! { - #enum_name::#variant_name => self.#method_name(tx) - }); } else { enum_variants .push(quote!(#variant_name(#(#variant_fields),*))); - // Extracting bindings for the match arm: Enum::Var(a, b) -> self.method(a, b, tx) - let pattern_args = (0..variant_fields.len()) - .map(|i| format_ident!("arg_{}", i)); - let pattern_args_clone = pattern_args.clone(); - - match_arms.push(quote! { - #enum_name::#variant_name(#(#pattern_args),*) => self.#method_name(#(#pattern_args_clone,)* tx) - }); } + + // Generate Match Arm + let pattern = if variant_arg_names.is_empty() { + quote!(#enum_name::#variant_name) + } else { + quote!(#enum_name::#variant_name(#(#variant_arg_names),*)) + }; + + let call = if returns_result { + quote!(self.#method_name(#(#call_args),*)) + } else { + quote!({ self.#method_name(#(#call_args),*); ::core::result::Result::Ok(()) }) + }; + + match_arms.push(quote!(#pattern => #call)); } } } @@ -105,6 +120,11 @@ pub fn component(attr: TokenStream, item: TokenStream) -> TokenStream { TokenStream::from(expanded) } +fn is_sender_type(ty: &Type) -> bool { + let s = quote!(#ty).to_string(); + s.contains("UnboundedSender") +} + fn snake_to_pascal(s: &str) -> String { s.split('_') .map(|word| { @@ -118,9 +138,3 @@ fn snake_to_pascal(s: &str) -> String { }) .collect() } - -// Helper to detect if an argument is the Sender to inject it from handle_action -fn is_sender_type(ty: &Type) -> bool { - let s = quote!(#ty).to_string(); - s.contains("UnboundedSender") -} diff --git a/src/app.rs b/src/app.rs index c49db3d..162447b 100644 --- a/src/app.rs +++ b/src/app.rs @@ -1,4 +1,5 @@ pub mod common; +pub mod run; pub mod seeding; pub mod snowballing; @@ -13,8 +14,8 @@ use tokio::{ time::sleep, }; -use crate::crossterm::Action; use crate::{ + app::run::Action, literature::{Publication, SnowballingHistory, get_publication_by_id}, status_error, status_info, }; @@ -76,25 +77,16 @@ pub struct App { #[component(GlobalAction)] impl App { #[action] - fn quit( - &mut self, - _: &'static UnboundedSender, - ) -> Result<(), SendError> { + fn quit(&mut self) { self.should_quit = true; - Ok(()) } #[action] - fn next_tab( - &mut self, - _: &'static UnboundedSender, - ) -> Result<(), SendError> { + fn next_tab(&mut self) { self.state.current_tab = match self.state.current_tab { Tab::Seeding => Tab::Snowballing, Tab::Snowballing => Tab::Seeding, }; - - Ok(()) } // TODO: Have status messages always last the same amount of time @@ -103,7 +95,7 @@ impl App { &mut self, msg: StatusMessage, action_tx: &'static UnboundedSender, - ) -> Result<(), SendError> { + ) { match &msg { StatusMessage::Error(_) => { error!("Status message: {:?}", msg) @@ -126,34 +118,21 @@ impl App { error!("{}", err); } }); - - Ok(()) } #[action] - fn clear_stat_msg( - &mut self, - _: &'static UnboundedSender, - ) -> Result<(), SendError> { + fn clear_stat_msg(&mut self) { self.state.status_message = StatusMessage::Info("".to_string()); - - Ok(()) } // TODO: Is deduplication necessary here? #[action] - fn add_included_pub( - &mut self, - publ: Publication, - _: &'static UnboundedSender, - ) -> Result<(), SendError> { + fn add_included_pub(&mut self, publ: Publication) { self.state .history .current_iteration .included_publications .push(publ.clone()); - - Ok(()) } #[action] @@ -219,10 +198,7 @@ impl App { &self, action_tx: &'static UnboundedSender, ) -> Result<(), SendError> { - status_info!( - action_tx, - "Fetch action triggered" - ) + status_info!(action_tx, "Fetch action triggered") } pub fn handle_key( @@ -231,40 +207,36 @@ impl App { action_tx: &'static UnboundedSender, ) -> Result<(), SendError> { match (self.state.current_tab, key) { - (_, KeyCode::Esc) => action_tx.send(GlobalAction::Quit.into())?, - (_, KeyCode::Tab) => { - action_tx.send(GlobalAction::NextTab.into())? - } + (_, KeyCode::Esc) => action_tx.send(GlobalAction::Quit.into()), + (_, KeyCode::Tab) => action_tx.send(GlobalAction::NextTab.into()), (Tab::Seeding, KeyCode::Char(c)) => { - action_tx.send(SeedingAction::EnterChar(c).into())?; + action_tx.send(SeedingAction::EnterChar(c).into()) } (Tab::Seeding, KeyCode::Backspace) => { - action_tx.send(SeedingAction::EnterBackspace.into())?; + action_tx.send(SeedingAction::EnterBackspace.into()) } (Tab::Seeding, KeyCode::Enter) => { - action_tx.send(GlobalAction::FetchPub.into())?; + action_tx.send(GlobalAction::FetchPub.into()) } (Tab::Snowballing, KeyCode::Enter) => { - action_tx.send(SnowballingAction::Search.into())?; + action_tx.send(SnowballingAction::Search.into()) } (Tab::Snowballing, KeyCode::Char('h')) => { - action_tx.send(SnowballingAction::SelectLeftPane.into())?; + action_tx.send(SnowballingAction::SelectLeftPane.into()) } (Tab::Snowballing, KeyCode::Char('l')) => { - action_tx.send(SnowballingAction::SelectRightPane.into())?; + action_tx.send(SnowballingAction::SelectRightPane.into()) } (Tab::Snowballing, KeyCode::Char('j')) => { - action_tx.send(SnowballingAction::NextItem.into())?; + action_tx.send(SnowballingAction::NextItem.into()) } (Tab::Snowballing, KeyCode::Char('k')) => { - action_tx.send(SnowballingAction::PrevItem.into())?; + action_tx.send(SnowballingAction::PrevItem.into()) } (Tab::Snowballing, KeyCode::Char(' ')) => { - action_tx.send(GlobalAction::Fetch.into())?; + action_tx.send(GlobalAction::Fetch.into()) } - _ => {} + _ => Ok(()), } - - Ok(()) } } diff --git a/src/app/run.rs b/src/app/run.rs new file mode 100644 index 0000000..53fee01 --- /dev/null +++ b/src/app/run.rs @@ -0,0 +1,97 @@ +use std::{error::Error, time::Duration}; + +use crossterm::event::{self, Event}; +use ratatui::{Terminal, prelude::Backend}; +use static_cell::StaticCell; +use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender}; + +use crate::{ + app::{ + App, AppState, GlobalAction, common::Component, seeding::SeedingAction, + snowballing::SnowballingAction, + }, + ui, +}; + +static ACTION_QUEUE_TX: StaticCell> = StaticCell::new(); +static ACTION_QUEUE_RX: StaticCell> = + StaticCell::new(); + +// TODO: Move this somewhere sensible +#[derive(Clone, Debug)] +pub enum Action { + Snowballing(SnowballingAction), + Seeding(SeedingAction), + Global(GlobalAction), +} + +impl From for Action { + fn from(action: GlobalAction) -> Self { + Action::Global(action) + } +} + +impl From for Action { + fn from(action: SnowballingAction) -> Self { + Action::Snowballing(action) + } +} + +impl From for Action { + fn from(action: SeedingAction) -> Self { + Action::Seeding(action) + } +} + +// TODO: Is there a way to completely decouple this from crossterm? +pub async fn run_app( + terminal: &mut Terminal, + app_state: AppState, +) -> Result> +where + ::Error: 'static, +{ + let (action_tx, action_rx): ( + UnboundedSender, + UnboundedReceiver, + ) = mpsc::unbounded_channel(); + + let action_tx_ref = ACTION_QUEUE_TX.init(action_tx); + let action_rx_ref = ACTION_QUEUE_RX.init(action_rx); + + let mut app = App { + state: app_state, + should_quit: false, + }; + + loop { + app.state.refresh_component_states(); // TODO: Is it a problem to call this every frame? + terminal.draw(|frame| ui::draw(frame, &mut app.state))?; + + if event::poll(Duration::from_millis(100))? { + if let Event::Key(key) = event::read()? { + app.handle_key(key.code, action_tx_ref)?; + } + } + + while let Ok(action) = action_rx_ref.try_recv() { + match action { + Action::Seeding(seeding_action) => app + .state + .seeding + .handle_action(seeding_action, action_tx_ref), + Action::Snowballing(snowballing_action) => app + .state + .snowballing + .handle_action(snowballing_action, action_tx_ref), + Action::Global(global_action) => { + app.handle_action(global_action, action_tx_ref) + } + }?; + } + + if app.should_quit { + return Ok(app.state); + } + } +} diff --git a/src/app/seeding.rs b/src/app/seeding.rs index ef06d84..1163e99 100644 --- a/src/app/seeding.rs +++ b/src/app/seeding.rs @@ -1,5 +1,4 @@ use serde::{Deserialize, Serialize}; -use tokio::sync::mpsc::{UnboundedSender, error::SendError}; use crate::literature::Publication; use brittling_macros::component; @@ -14,31 +13,19 @@ pub struct SeedingComponent { #[component(SeedingAction)] impl SeedingComponent { #[action] - pub fn clear_input( - &mut self, - _: &UnboundedSender, - ) -> Result<(), SendError> { - Ok(self.input.clear()) + pub fn clear_input(&mut self) { + self.input.clear() } #[action] - pub fn enter_char( - &mut self, - c: char, - _: &UnboundedSender, - ) -> Result<(), SendError> { - Ok(self.input.push(c)) + pub fn enter_char(&mut self, c: char) { + self.input.push(c) } #[action] - pub fn enter_backspace( - &mut self, - _: &UnboundedSender, - ) -> Result<(), SendError> { + pub fn enter_backspace(&mut self) { if self.input.len() > 0 { self.input.truncate(self.input.len() - 1); } - - Ok(()) } } diff --git a/src/app/snowballing.rs b/src/app/snowballing.rs index d28079c..583f66a 100644 --- a/src/app/snowballing.rs +++ b/src/app/snowballing.rs @@ -1,8 +1,6 @@ use brittling_macros::component; use ratatui::widgets::ListState; use serde::{Deserialize, Serialize}; -use tokio::sync::mpsc::UnboundedSender; -use tokio::sync::mpsc::error::SendError; use crate::literature::Publication; @@ -101,38 +99,25 @@ impl SnowballingComponent { } #[action] - fn select_left_pane( - &mut self, - _: &UnboundedSender, - ) -> Result<(), SendError> { + fn select_left_pane(&mut self) { self.active_pane = ActivePane::IncludedPublications; if let None = self.included_list_state.selected() { self.included_list_state.select(Some(0)); } - - Ok(()) } #[action] - fn select_right_pane( - &mut self, - _: &UnboundedSender, - ) -> Result<(), SendError> { + fn select_right_pane(&mut self) { self.active_pane = ActivePane::PendingPublications; if let None = self.pending_list_state.selected() { self.pending_list_state.select(Some(0)); } - - Ok(()) } #[action] - fn search( - &self, - _: &UnboundedSender, - ) -> Result<(), SendError> { + fn search(&self) { match self.active_pane { ActivePane::IncludedPublications => { if let Some(idx) = self.included_list_state.selected() { @@ -145,15 +130,10 @@ impl SnowballingComponent { } } } - - Ok(()) } #[action] - fn next_item( - &mut self, - _: &UnboundedSender, - ) -> Result<(), SendError> { + fn next_item(&mut self) { match self.active_pane { ActivePane::IncludedPublications => { Self::select_next_item_impl( @@ -168,15 +148,10 @@ impl SnowballingComponent { ); } } - - Ok(()) } #[action] - fn prev_item( - &mut self, - _: &UnboundedSender, - ) -> Result<(), SendError> { + fn prev_item(&mut self) { match self.active_pane { ActivePane::IncludedPublications => { Self::select_prev_item_impl( @@ -191,7 +166,5 @@ impl SnowballingComponent { ); } } - - Ok(()) } } diff --git a/src/crossterm.rs b/src/crossterm.rs index 3d8ae06..cc282b7 100644 --- a/src/crossterm.rs +++ b/src/crossterm.rs @@ -1,11 +1,11 @@ -use std::{error::Error, io, time::Duration}; +use std::{error::Error, io}; use log::error; use ratatui::{ Terminal, - backend::{Backend, CrosstermBackend}, + backend::CrosstermBackend, crossterm::{ - event::{self, DisableMouseCapture, EnableMouseCapture, Event}, + event::{DisableMouseCapture, EnableMouseCapture}, execute, terminal::{ EnterAlternateScreen, LeaveAlternateScreen, disable_raw_mode, @@ -13,16 +13,8 @@ use ratatui::{ }, }, }; -use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender}; -use crate::{ - app::{App, AppState, common::Component}, - ui, -}; - -use crate::app::GlobalAction; -use crate::app::seeding::SeedingAction; -use crate::app::snowballing::SnowballingAction; +use crate::app::{AppState, run::run_app}; pub async fn run(app_state: AppState) -> Result> { // setup terminal @@ -51,88 +43,3 @@ pub async fn run(app_state: AppState) -> Result> { Ok(app_result?) } - -use static_cell::StaticCell; - -static ACTION_QUEUE_TX: StaticCell> = StaticCell::new(); -static ACTION_QUEUE_RX: StaticCell> = - StaticCell::new(); - -// TODO: Move this somewhere sensible -#[derive(Clone, Debug)] -pub enum Action { - Snowballing(SnowballingAction), - Seeding(SeedingAction), - Global(GlobalAction), -} - -impl From for Action { - fn from(action: GlobalAction) -> Self { - Action::Global(action) - } -} - -impl From for Action { - fn from(action: SnowballingAction) -> Self { - Action::Snowballing(action) - } -} - -impl From for Action { - fn from(action: SeedingAction) -> Self { - Action::Seeding(action) - } -} - -async fn run_app( - terminal: &mut Terminal, - app_state: AppState, -) -> Result> -where - ::Error: 'static, -{ - let (action_tx, action_rx): ( - UnboundedSender, - UnboundedReceiver, - ) = mpsc::unbounded_channel(); - - let action_tx_ref = ACTION_QUEUE_TX.init(action_tx); - let action_rx_ref = ACTION_QUEUE_RX.init(action_rx); - - let mut app = App { - state: app_state, - should_quit: false, - }; - - loop { - app.state.refresh_component_states(); // TODO: Is it a problem to call this every frame? - terminal.draw(|frame| ui::draw(frame, &mut app.state))?; - - if event::poll(Duration::from_millis(100))? { - if let Event::Key(key) = event::read()? { - app.handle_key(key.code, action_tx_ref)?; - } - } - - while let Ok(action) = action_rx_ref.try_recv() { - // TODO: Handle errors - match action { - Action::Seeding(seeding_action) => app - .state - .seeding - .handle_action(seeding_action, action_tx_ref), - Action::Snowballing(snowballing_action) => app - .state - .snowballing - .handle_action(snowballing_action, action_tx_ref), - Action::Global(global_action) => { - app.handle_action(global_action, action_tx_ref) - } - }; - } - - if app.should_quit { - return Ok(app.state); - } - } -}