diff --git a/src/catalog.rs b/src/catalog.rs index f9672f2..e7f476b 100644 --- a/src/catalog.rs +++ b/src/catalog.rs @@ -184,7 +184,10 @@ pub struct Catalog { } impl Catalog { - pub fn populate(&mut self) -> anyhow::Result<()> { + pub fn populate_with_progress( + &mut self, + progress_bar: Option, + ) -> anyhow::Result<()> { let mut new_entries = BTreeMap::new(); let mut old_entries = Arc::new(std::mem::take(&mut self.entries)); @@ -218,14 +221,46 @@ impl Catalog { move |(_, relpath)| !old_entries.contains_key(relpath) }); let metadata = self.metadata.clone(); - for result in crate::parallel::for_each(iterator, move |res| { - res.and_then(|(entry, relative_path)| { - checksum_entry(metadata.algo, entry, relative_path) + + let results_iter = if let Some(ref bar) = progress_bar { + let discovery_bar = bar.clone(); + let progress_bar = progress_bar.clone(); + crate::parallel::for_each_with_discovery_callback( + iterator, + move |res| { + let result = res.and_then(|(entry, relative_path)| { + checksum_entry(metadata.algo, entry, relative_path) + }); + if let Some(ref bar) = progress_bar { + bar.notify_record_processed(result.as_ref().map(|(size, _, _)| *size).ok()); + } + result + }, + Some(Box::new(move || { + discovery_bar.notify_file_discovered(); + })), + ) + } else { + crate::parallel::for_each(iterator, move |res| { + res.and_then(|(entry, relative_path)| { + checksum_entry(metadata.algo, entry, relative_path) + }) }) - }) { - let (entry, relative_filename, signature) = result?; + }; + + let results_iter = if let Some(ref bar) = progress_bar { + let bar_clone = bar.clone(); + results_iter.with_total_callback(move |total| { + bar_clone.set_length(total); + }) + } else { + results_iter + }; + + for result in results_iter { + let (_, relative_filename, signature) = result?; let prev = new_entries.insert(relative_filename, signature); - assert!(prev.is_none(), "Entry {entry:?} was already in catalog!") + assert!(prev.is_none(), "Entry was already in catalog!") } assert!(self.entries.is_empty()); diff --git a/src/main.rs b/src/main.rs index 3705612..63fee65 100644 --- a/src/main.rs +++ b/src/main.rs @@ -156,7 +156,7 @@ fn load_and_verify_catalog( let catalog_filename = catalog.metadata().signature_file_path().clone(); let algo = catalog.metadata().algo(); - let bar = crate::progress::ProgressBar::new(catalog.len()); + let bar = crate::progress::ProgressBar::new(Some(catalog.len())); let mut report = crate::parallel::for_each(catalog.into_iter(), move |entry| { let res = entry @@ -194,7 +194,9 @@ fn create_catalog(params: SignParams, config: &config::Config) -> anyhow::Result let mut catalog = directory.empty_catalog(algo); - catalog.populate()?; + let bar = crate::progress::ProgressBar::new(None); + + catalog.populate_with_progress(Some(bar))?; catalog.write_signature_file(false) } diff --git a/src/parallel.rs b/src/parallel.rs index 11c68d5..ab56b42 100644 --- a/src/parallel.rs +++ b/src/parallel.rs @@ -1,25 +1,46 @@ +use std::sync::Arc; use std::thread::JoinHandle; -pub(crate) fn for_each(iterator: I, handler: F) -> impl Iterator +pub(crate) fn for_each(iterator: I, handler: F) -> ResultsIterator where I: Iterator + Send + 'static, R: Send + 'static, F: Fn(T) -> R + Send + Clone + 'static, T: Send + 'static, { - let (entries_sender, entries_receiver) = crossbeam_channel::bounded(num_cpus::get() * 4); + for_each_with_discovery_callback(iterator, handler, None) +} +pub(crate) fn for_each_with_discovery_callback( + iterator: I, + handler: F, + discovery_callback: Option>, +) -> ResultsIterator +where + I: Iterator + Send + 'static, + R: Send + 'static, + F: Fn(T) -> R + Send + Clone + 'static, + T: Send + 'static, +{ + let (entries_sender, entries_receiver) = crossbeam_channel::unbounded(); let (results_sender, results_receiver) = crossbeam_channel::unbounded(); + let (total_sender, total_receiver) = crossbeam_channel::bounded(1); + + let discovery_callback = discovery_callback.map(Arc::new); let producer = std::thread::spawn(move || { let mut size = 0; for entry in iterator { size += 1; + if let Some(ref callback) = discovery_callback { + callback(); + } if entries_sender.send(entry).is_err() { log::debug!("Entries sender channel closed. Closing producer"); break; } } + let _ = total_sender.send(size); size }); @@ -37,33 +58,51 @@ where } drop(results_sender); - ResultsIterator::new(producer, results_receiver) + ResultsIterator::new(producer, results_receiver, total_receiver) } pub struct ResultsIterator { - join_handle: Option>, total: Option, receiver: crossbeam_channel::Receiver, + total_receiver: crossbeam_channel::Receiver, received: usize, + on_total_discovered: Option>, } impl ResultsIterator { - fn new(join_handle: JoinHandle, receiver: crossbeam_channel::Receiver) -> Self { + fn new( + _join_handle: JoinHandle, + receiver: crossbeam_channel::Receiver, + total_receiver: crossbeam_channel::Receiver, + ) -> Self { Self { - join_handle: Some(join_handle), total: None, receiver, + total_receiver, received: 0, + on_total_discovered: None, } } + + pub fn with_total_callback(mut self, callback: F) -> Self + where + F: FnOnce(usize) + Send + 'static, + { + self.on_total_discovered = Some(Box::new(callback)); + self + } } impl Iterator for ResultsIterator { type Item = R; fn next(&mut self) -> Option { - if let Some(handle) = self.join_handle.take() { - if handle.is_finished() { - self.total.replace(handle.join().unwrap()); + // Check if we received the total count + if self.total.is_none() { + if let Ok(total) = self.total_receiver.try_recv() { + self.total.replace(total); + if let Some(callback) = self.on_total_discovered.take() { + callback(total); + } } } @@ -73,8 +112,34 @@ impl Iterator for ResultsIterator { } } - // we either didn't finish, or we are still missing our total - self.receiver.recv().ok() + // Get the next result + match self.receiver.recv().ok() { + Some(item) => { + self.received += 1; + // Check again for total after receiving an item + if self.total.is_none() { + if let Ok(total) = self.total_receiver.try_recv() { + self.total.replace(total); + if let Some(callback) = self.on_total_discovered.take() { + callback(total); + } + } + } + Some(item) + } + None => { + // No more results, but check one last time for total + if self.total.is_none() { + if let Ok(total) = self.total_receiver.try_recv() { + self.total.replace(total); + if let Some(callback) = self.on_total_discovered.take() { + callback(total); + } + } + } + None + } + } } } diff --git a/src/progress.rs b/src/progress.rs index b61f698..386145f 100644 --- a/src/progress.rs +++ b/src/progress.rs @@ -8,21 +8,38 @@ const SIZE_UPDATE_FREQ: std::time::Duration = std::time::Duration::from_secs(3); pub struct ProgressBar { bar: Arc, size: Arc, + discovered_count: Arc, } impl ProgressBar { - pub fn new(len: usize) -> Self { - let bar = indicatif::ProgressBar::new(len.try_into().unwrap()); - bar.set_draw_target(ProgressDrawTarget::stderr_with_hz(5)); - bar.set_style( - ProgressStyle::with_template( - "[{elapsed_precise}] {bar:40.cyan/blue} {pos:>7}/{len:7} {msg}", - ) - .unwrap(), - ); + pub fn new(len: Option) -> Self { + let bar = match len { + Some(length) => { + let bar = indicatif::ProgressBar::new(length.try_into().unwrap()); + bar.set_style( + ProgressStyle::with_template( + "[{elapsed_precise}] {bar:40.cyan/blue} {pos:>7}/{len:7} {msg}", + ) + .unwrap(), + ); + bar + } + None => { + let bar = indicatif::ProgressBar::new_spinner(); + bar.set_style( + ProgressStyle::with_template( + "[{elapsed_precise}] {spinner:.cyan/blue} {pos:>7}/{prefix}+ {msg}", + ) + .unwrap(), + ); + bar + } + }; + bar.set_draw_target(ProgressDrawTarget::stderr_with_hz(5)); let bar = Arc::new(bar); let size = Arc::new(AtomicU64::default()); + let discovered_count = Arc::new(AtomicU64::default()); let bar_weak = Arc::downgrade(&bar); let size_weak = Arc::downgrade(&size); @@ -52,7 +69,30 @@ impl ProgressBar { } }); - Self { bar, size } + Self { + bar, + size, + discovered_count, + } + } + + pub fn set_length(&self, len: usize) { + self.bar.set_length(len.try_into().unwrap()); + self.bar.set_style( + ProgressStyle::with_template( + "[{elapsed_precise}] {bar:40.cyan/blue} {pos:>7}/{len:7} {msg}", + ) + .unwrap(), + ); + } + + pub fn notify_file_discovered(&self) { + self.discovered_count + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + let count = self + .discovered_count + .load(std::sync::atomic::Ordering::Relaxed); + self.bar.set_prefix(count.to_string()); } pub fn notify_record_processed(&self, record_size: Option) {