+ use core::hash::BuildHasher as _;
+ use core::hash::Hash as _;
+ use core::hash::Hasher as _;
+
+ use beancount_types::Acc;
+ use beancount_types::Account;
+ use beancount_types::Amount;
+ use beancount_types::Balance;
+ use beancount_types::Commodity;
+ use beancount_types::Directive;
+ use camino::Utf8Path;
+ use delegate::delegate;
+ use hashbrown::hash_map::RawEntryMut;
+ use hashbrown::HashMap;
+ use time::Date;
+
+ use crate::ImporterProtocol;
+
+ /// Importer adapter that collapses duplicate balance directives emitted by a
+ /// wrapped importer.
+ ///
+ /// Balances that share the same date, account and commodity are considered
+ /// duplicates. Of each such group only one balance survives: the one whose
+ /// `key` projection compares greatest (the earliest one wins on ties).
+ ///
+ /// The bounds live on the type itself because `T` appears in no field; it is
+ /// only tied to the struct through the return type of `F`.
+ pub struct DeduplicateBalances<I, T, F>
+ where
+     F: Fn(&Balance) -> &T,
+     I: ImporterProtocol,
+     T: Ord,
+ {
+     /// Wrapped importer that produces the directives to deduplicate.
+     inner: I,
+
+     /// Projection yielding the value used to rank colliding balances.
+     key: F,
+ }
+
+ impl<I, T, F> DeduplicateBalances<I, T, F>
+ where
+     F: Fn(&Balance) -> &T,
+     I: ImporterProtocol,
+     T: Ord,
+ {
+     /// Wraps `inner`, ranking colliding balances by the value `key` projects
+     /// out of each of them.
+     pub fn new(inner: I, key: F) -> Self {
+         DeduplicateBalances { inner, key }
+     }
+ }
+
+ impl<I, T, F> DeduplicateBalances<I, T, F>
+ where
+     F: Fn(&Balance) -> &T,
+     I: ImporterProtocol,
+     T: Ord,
+ {
+     /// Applies the user-supplied projection to `balance`.
+     fn key<'b>(&self, balance: &'b Balance) -> &'b T {
+         let project = &self.key;
+         project(balance)
+     }
+
+     /// Records `balance` in `map` under its (date, account, commodity) triple.
+     ///
+     /// If a balance is already stored for that triple, it is replaced only
+     /// when the new balance's key is strictly greater; otherwise the stored
+     /// one is kept and `balance` is dropped.
+     fn upsert(&self, map: &mut HashMap<StorageKey, Balance>, balance: Balance) {
+         // Borrowed view of the identifying fields, so the lookup itself does
+         // not need an owned `StorageKey`.
+         let query = QueryKey {
+             date: balance.date,
+             account: &balance.account,
+             commodity: &balance.amount.commodity,
+         };
+
+         // Hash the borrowed key with the map's own hasher so the raw lookup
+         // probes the same bucket an owned key would land in.
+         let mut state = map.hasher().build_hasher();
+         query.hash(&mut state);
+         let hash = state.finish();
+
+         let slot = map
+             .raw_entry_mut()
+             .from_hash(hash, |stored| stored == query);
+
+         match slot {
+             RawEntryMut::Vacant(vacant) => {
+                 // Only now is the owned key materialised.
+                 vacant.insert(StorageKey::from(query), balance);
+             }
+             RawEntryMut::Occupied(mut occupied) => {
+                 let outranks = self.key(occupied.get()) < self.key(&balance);
+                 if outranks {
+                     occupied.insert(balance);
+                 }
+             }
+         }
+     }
+ }
+
+ impl<I, T, F> ImporterProtocol for DeduplicateBalances<I, T, F>
+ where
+     F: Fn(&Balance) -> &T,
+     I: ImporterProtocol,
+     T: Ord,
+ {
+     type Error = I::Error;
+
+     // These methods are forwarded unchanged to the wrapped importer.
+     delegate! {
+         to (self.inner) {
+             fn account(&self, file: &Utf8Path) -> Result<Account, Self::Error>;
+             fn date(&self, file: &Utf8Path) -> Option<Result<Date, Self::Error>>;
+             fn filename(&self, file: &Utf8Path) -> Option<Result<String, Self::Error>>;
+             fn identify(&self, file: &Utf8Path) -> Result<bool, Self::Error>;
+             fn name(&self) -> &'static str;
+             fn typetag_deserialize(&self);
+         }
+     }
+
+     /// Runs the wrapped importer's `extract` and deduplicates the balances in
+     /// its output.
+     ///
+     /// Non-balance directives keep their relative order and come first. The
+     /// surviving balances are appended afterwards in the iteration order of
+     /// the intermediate hash map.
+     ///
+     /// # Errors
+     ///
+     /// Propagates any error returned by the wrapped importer's `extract`.
+     fn extract(
+         &self,
+         file: &camino::Utf8Path,
+         existing: &[Directive],
+     ) -> Result<Vec<Directive>, Self::Error> {
+         let mut deduplicated = HashMap::new();
+         let mut output = Vec::new();
+
+         for directive in self.inner.extract(file, existing)? {
+             match directive {
+                 Directive::Balance(balance) => self.upsert(&mut deduplicated, balance),
+                 other => output.push(other),
+             }
+         }
+
+         output.extend(deduplicated.into_values().map(Directive::from));
+
+         Ok(output)
+     }
+ }
+
+ /// Extension methods for types implementing `ImporterProtocol`.
+ pub trait ImporterProtocolExt {
+     /// Wraps this importer in a [`DeduplicateBalances`] adapter that ranks
+     /// colliding balances by the value `key` projects out of them.
+     fn deduplicate_balances_by<T, F>(self, key: F) -> DeduplicateBalances<Self, T, F>
+     where
+         Self: ImporterProtocol + Sized,
+         T: Ord,
+         F: Fn(&Balance) -> &T,
+     {
+         DeduplicateBalances::<Self, T, F>::new(self, key)
+     }
+ }
+
+ // Blanket impl: every (sized) importer gets the extension methods.
+ impl<I: ImporterProtocol> ImporterProtocolExt for I {}
+
+ /// Owned map key identifying a balance by date, account and commodity.
+ ///
+ /// Lookups hash a borrowed `QueryKey` instead of this type, so the field
+ /// order here must mirror `QueryKey`: the derived `Hash` feeds fields to the
+ /// hasher in declaration order.
+ #[derive(Debug, PartialEq, Eq, Hash)]
+ struct StorageKey {
+     date: Date,
+     account: Account,
+     commodity: Commodity,
+ }
+
+ /// Converts a borrowed lookup key into an owned map key.
+ impl From<QueryKey<'_>> for StorageKey {
+     fn from(query: QueryKey) -> Self {
+         Self {
+             date: query.date,
+             account: query.account.to_owned(),
+             commodity: *query.commodity,
+         }
+     }
+ }
+
+ /// Field-wise comparison between a stored key and a borrowed lookup key.
+ ///
+ /// Implemented on `&StorageKey` because that is what the raw-entry match
+ /// closure receives.
+ impl PartialEq<QueryKey<'_>> for &StorageKey {
+     fn eq(&self, other: &QueryKey) -> bool {
+         let StorageKey {
+             date,
+             account,
+             commodity,
+         } = *self;
+
+         *date == other.date && *account == other.account && commodity == other.commodity
+     }
+ }
+
+ /// Borrowed counterpart of `StorageKey`, used to probe the map without
+ /// building an owned key first.
+ ///
+ /// Field order must stay in sync with `StorageKey` so that both derived
+ /// `Hash` impls feed the hasher in the same sequence.
+ #[derive(Debug, Hash)]
+ struct QueryKey<'q> {
+     date: Date,
+     account: &'q Acc,
+     commodity: &'q Commodity,
+ }