// Experimenting with more structured ways to handle command-line input/output in Rust
use crate::fluent;

use icu_locale::locale;
use quote::quote;
use syn::{parse_quote, parse_quote_spanned, spanned::Spanned};
use thiserror::Error;

mod attribute;
pub mod derive;
pub mod error;

#[derive(Debug, Error)]
pub enum UnsupportedReason {
    #[error("Unions are not supported")]
    Union,
    #[error("Unnamed fields are not supported")]
    UnnamedFields,
}

#[derive(Debug, Error)]
#[error("Unsupported Rust code")]
pub struct UnsupportedError {
    span: syn::Ident,
    reason: UnsupportedReason,
}

#[derive(Debug, Error)]
#[error("Failed to parse macro input")]
pub enum ParseError {
    #[error("invalid attribute")]
    InvalidAttribute(syn::Error),
    #[error("invalid item")]
    InvalidDeriveInput(syn::Error),
}

/// Every error the `localize` macro can report.
///
/// Each variant is transparent: its `Display` output and `source` are taken
/// directly from the wrapped error, and `#[from]` lets `?` convert any of the
/// wrapped error types into a `MacroError` automatically.
#[derive(Debug, Error)]
pub enum MacroError {
    /// Failure while interpreting the attribute's locale specification.
    #[error(transparent)]
    Attribute(#[from] attribute::Error),
    /// Failure while building the group of Fluent resources.
    #[error(transparent)]
    Group(#[from] fluent::GroupError),
    /// The annotated item uses Rust constructs the macro does not support.
    #[error(transparent)]
    Unsupported(#[from] UnsupportedError),
    /// One of the input token streams failed to parse.
    #[error(transparent)]
    ParseError(#[from] ParseError),
}

pub fn localize(
    attribute_stream: proc_macro2::TokenStream,
    derive_input_stream: proc_macro2::TokenStream,
) -> Result<proc_macro2::TokenStream, MacroError> {
    // Set up a global Miette report handler to ensure consistent non-Rustc diagnostics
    // If this returns an error, it just means the hook has already been set.
    let _result = miette::set_hook(Box::new(|_| {
        Box::new(
            miette::MietteHandlerOpts::new()
                // Force color output, even when printing using the debug formatter
                .color(true)
                .build(),
        )
    }));
    miette::set_panic_hook();

    // Parse the token streams
    let attribute: syn::LitStr = syn::parse2(attribute_stream)
        .map_err(|parse_error| ParseError::InvalidAttribute(parse_error))?;
    let derive_input: syn::DeriveInput = syn::parse2(derive_input_stream)
        .map_err(|parse_error| ParseError::InvalidDeriveInput(parse_error))?;

    let locales = attribute::locales(&attribute)?;

    // Keep track of all the Fluent files
    let tracked_paths = locales.clone();
    let tracked_paths = tracked_paths
        .values()
        .map(|path| path.to_string_lossy().to_string());
    let path_count = locales.len();

    // TODO: user-controlled canonical locale
    let group = fluent::Group::new(locale!("en-US"), locales)?;
    let canonical_locale = group.canonical_locale().id.clone().to_string();

    let available_locales = match &derive_input.data {
        syn::Data::Struct(_struct_data) => derive::locales_for_ident(&group, &derive_input.ident),
        syn::Data::Enum(enum_data) => derive::locales_for_enum(&group, &enum_data.variants),
        syn::Data::Union(_) => {
            return Err(MacroError::Unsupported(UnsupportedError {
                span: derive_input.ident.clone(),
                reason: UnsupportedReason::Union,
            }))
        }
    };

    let message_body = match &derive_input.data {
        syn::Data::Struct(struct_data) => {
            derive::message_for_struct(group, &derive_input.ident, &struct_data.fields)
        }
        syn::Data::Enum(enum_data) => derive::messages_for_enum(group, &enum_data.variants),
        syn::Data::Union(_) => {
            return Err(MacroError::Unsupported(UnsupportedError {
                span: derive_input.ident.clone(),
                reason: UnsupportedReason::Union,
            }))
        }
    }?;

    let ident = &derive_input.ident;

    // Get the original generics for the derived item
    let (initial_impl_generics, initial_type_generics, initial_where_clause) =
        derive_input.generics.split_for_impl();

    // Get the types of each named field
    let named_fields: Vec<&syn::Type> = match &derive_input.data {
        syn::Data::Struct(struct_data) => match &struct_data.fields {
            syn::Fields::Named(named_fields) => {
                named_fields.named.iter().map(|field| &field.ty).collect()
            }
            syn::Fields::Unit => Vec::new(),
            syn::Fields::Unnamed(_unnamed_fields) => {
                return Err(MacroError::Unsupported(UnsupportedError {
                    span: derive_input.ident.clone(),
                    reason: UnsupportedReason::UnnamedFields,
                }))
            }
        },
        syn::Data::Enum(enum_data) => enum_data
            .variants
            .iter()
            .map(|variant| match &variant.fields {
                syn::Fields::Named(named_fields) => {
                    Ok(named_fields.named.iter().map(|field| &field.ty).collect())
                }
                syn::Fields::Unit => Ok(Vec::new()),
                syn::Fields::Unnamed(_unnamed_fields) => {
                    Err(MacroError::Unsupported(UnsupportedError {
                        span: variant.ident.clone(),
                        reason: UnsupportedReason::UnnamedFields,
                    }))
                }
            })
            .collect::<Result<Vec<Vec<&syn::Type>>, _>>()?
            .into_iter()
            .flatten()
            .collect(),
        syn::Data::Union(_union_data) => todo!(),
    };

    // Add a bound on `Localize` for each field's type
    let mut generics = derive_input.generics.clone();
    let additional_bounds = named_fields
        .into_iter()
        .map(|field| -> syn::WherePredicate {
            // Attribute this bound to the original source code
            let span = field.span();
            parse_quote_spanned!(span=> #field: ::fluent_embed::Localize<W>)
        });
    generics
        .make_where_clause()
        .predicates
        .extend(additional_bounds);

    // Define a parameter of `std::io::Write` for `Localize`
    // e.g. for MyStruct<'a, T>, it will be <'a, T, W: std::io::Write>
    generics
        .params
        .push(syn::GenericParam::Type(parse_quote!(W: std::io::Write)));
    let (impl_generics, _type_generics, where_clause) = generics.split_for_impl();

    Ok(quote! {
        impl #initial_impl_generics #ident #initial_type_generics #initial_where_clause {
            // Call the `include_str!` macro to make sure the Fluent files are tracked
            // so when Fluent code changes, the generated code should be rebuild
            // TODO: This is a hack that should be replaced with https://github.com/rust-lang/rust/issues/99515 once stable
            const _TRACKED_PATHS: [&'static str; #path_count] = [#(include_str!(#tracked_paths)),*];
        }

        impl #impl_generics ::fluent_embed::Localize<W> for #ident #initial_type_generics #where_clause {
            const CANONICAL_LOCALE: ::fluent_embed::icu_locale::LanguageIdentifier =
                ::fluent_embed::icu_locale::langid!(#canonical_locale);

            fn available_locales(&self) -> Vec<::fluent_embed::icu_locale::LanguageIdentifier> {
                #available_locales
            }

            fn message_for_locale(
                &self,
                writer: &mut W,
                locale: &::fluent_embed::icu_locale::LanguageIdentifier,
            ) -> Result<(), ::fluent_embed::LocalizationError> {
                #message_body
            }
        }
    })
}