From mathcomp Require Import all_ssreflect ssralg.
Require Import isomorphism axioms.

(* Types equipped with a distinguished element [zero]; used below for the
   zero points of the infinitesimal spaces. *)
Class Zero (T : Type) := { zero : T }.

(* Nilpotent elements of R. [nilpotent n] contains the x with x ^+ n = 0;
   [nilpotent_union] contains the x that are nilpotent of some order.
   The nth-order infinitesimals are [nilpotent n.+1] (written 'D_n below). *)

Definition nilpotent (n : nat) := {x : R | x ^+ n = 0}.
Definition nilpotent_union := {x : R | exists n : nat, x ^+ n = 0}.

(* 'D_n : the x with x ^+ n.+1 = 0, i.e. [nilpotent n.+1]. *)
Notation "''D_' n" := (nilpotent (n.+1))
    (at level 8, n at level 2, format "''D_' n").

(* 'D_-1 : [nilpotent 0], the x with x ^+ 0 = 0; empty, since x ^+ 0 = 1. *)
Notation "''D_-1'" := (nilpotent 0)
    (at level 8, format "''D_-1'").

(* 'D_∞ : elements nilpotent of some order, i.e. [nilpotent_union]. *)
Notation "''D_∞'" := nilpotent_union
    (at level 8, format "''D_∞'").

(* Forget the nilpotency witness and keep the underlying element of R. *)
Coercion R_of_nilpotent_union : nilpotent_union -> R := fun d => proj1_sig d.

(* An element with x ^+ n = 0 is nilpotent of some order, namely n. *)
Coercion union_of_nilpotent {n} : nilpotent n -> 'D_∞ :=
    fun '(exist x pf) => exist _ x (ex_intro _ n pf).


(* If x ^+ m = 0 and m <= n then x ^+ n = 0: write n as (n - m) + m,
   split the power with [exprD] and cancel the vanishing factor. *)
Fact higher_power_still_zero {x : R} {m n} (leq : m <= n)
    : x ^+ m = 0 -> x ^+ n = 0.
Proof.
    move=> xm0.
    by rewrite -(subnK leq) exprD xm0 mulr0.
Qed.

(* Relax the order bound: the same element, with its witness upgraded from
   x ^+ m = 0 to x ^+ n = 0 via [higher_power_still_zero]. *)
Definition widen_nilpotent {m n} (leq : m <= n) : nilpotent m -> nilpotent n :=
    fun d =>
        match d with
        | exist x pf => exist _ x (higher_power_still_zero leq pf)
        end.

(* The zero element of 'D_n; the side condition 0 ^+ n.+1 = 0 follows
   from [expr0n]. *)
#[refine]
Global Instance D_zero {n} : Zero 'D_n := {
    zero := exist _ 0 _
}.
Proof.
    exact: expr0n.
Defined.

(* Zero of 'D_∞: the zero of 'D_0, pushed through [union_of_nilpotent]. *)
Global Instance nilpotent_union_zero : Zero nilpotent_union := {
    zero := union_of_nilpotent (zero : 'D_0)
}.



(* 'D_-1 is empty: an inhabitant x satisfies x ^+ 0 = 0, i.e. 1 = 0 in R,
   contradicting [oner_neq0]. Hence it is isomorphic to [void]. *)
Definition d_minus_void_iso : 'D_-1 <--> void.
Proof.
    have fwd : 'D_-1 -> void.
    {
        move=> [x one_eq_zero].
        rewrite expr0 in one_eq_zero.
        by move: (oner_neq0 R) => /eqP.
    }
    refine ({|
        Forward := fwd;
        Backward := of_void _;
        (* Both round trips are vacuous: their arguments live in empty types. *)
        BF x := match fwd x with end;
        FB x := match x with end;
    |}).
Defined.

(* Evaluates at [x] the polynomial whose coefficients are [coeffs], i.e. the
   sum over i < n of x ^+ i * coeffs i (equal to 0 when n = 0). This is the
   function that [KL_D_back] attaches to a coefficient vector; the monomials
   are added, not multiplied. *)
Definition polynomial {n} (coeffs : 'I_n -> R) (x : R) : R
    := \sum_(i < n) ((x ^+ i) * coeffs i).

(* Implementation at the bottom of the file. *)
(* Kock-Lawvere style interface for nilpotent infinitesimals.
   [KL_D]: functions [nilpotent n -> R] correspond to coefficient vectors
   ['I_n -> R].
   [KL_D_back]: the backward direction sends coefficients to the function
   d |-> polynomial coeffs d. *)
Module Type nilpotent_properties_sig.
    Axiom KL_D : forall {n}, (nilpotent n -> R) <--> ('I_n -> R).
    Axiom KL_D_back : forall {n coeffs} {d : nilpotent n},
        Backward _ _ KL_D coeffs d = polynomial coeffs d.
End nilpotent_properties_sig.










(* Spaces of infinitesimal vectors. *)

(* Conjecture: this space may behave like a walking simplex. *)
(* Δ n: n-tuples over R all of whose pairwise products vanish (squares
   included, i = j). [Rn_of_Δ] views such a tuple as a finite function. *)
Definition Δ (n : nat) := {x : 'I_n -> R | forall i j : 'I_n, x i * x j = 0}.
Coercion Rn_of_Δ {n} : Δ n -> {ffun 'I_n -> R} := fun d => finfun (proj1_sig d).

(* The zero of Δ n is the constant zero tuple; every pairwise product is
   0 * 0 = 0. *)
#[refine]
Global Instance Δ_zero {n} : Zero (Δ n) := {
    zero := exist _ (fun=> 0) (fun _ _ => _)
}.
Proof.
    exact: mul0r.
Defined.

(* Extensionality for Δ n: two elements agreeing at every coordinate are
   equal. The underlying functions agree by functional extensionality, and
   the membership proofs by proof irrelevance. *)
Lemma eq_Δ {n} {u v : Δ n} (pf : forall i, u i = v i) : u = v.
Proof.
    move: u v pf => [u pf_u] [v pf_v] pf.
    rewrite /Rn_of_Δ in pf; simpl in pf.
    have uv : u = v.
    {
        apply: functional_extensionality => i.
        by move: (pf i); rewrite !ffunE.
    }
    subst v.
    f_equal.
    apply: proof_irrelevance.
Qed.

(* Forward map Δ 1 -> 'D_1: keep the single coordinate x ord0. Its square
   vanishes because the Δ condition at i = j = ord0 gives
   x ord0 * x ord0 = 0, and [expr2] rewrites the square as that product. *)
Definition Δ1_D_iso_fwd : Δ 1 -> 'D_1 :=
    fun d =>
        match d with
        | exist x pf =>
            exist _ (x ord0) (Logic.eq_trans (expr2 (x ord0)) (pf ord0 ord0))
        end.

(* Backward map 'D_1 -> Δ 1: the constant one-tuple at x. Each pairwise
   product is x * x, which equals x ^+ 2 = 0 by [expr2]. *)
Definition Δ1_D_iso_bwd : 'D_1 -> Δ 1 :=
    fun d =>
        match d with
        | exist x pf =>
            exist _ (fun=> x)
                (fun _ _ => Logic.eq_trans (Logic.eq_sym (expr2 x)) pf)
        end.

(* Δ 1 and 'D_1 are isomorphic via [Δ1_D_iso_fwd] and [Δ1_D_iso_bwd]. *)
Definition Δ1_D_iso : Δ 1 <--> 'D_1.
Proof.
    refine ({|
        Forward := Δ1_D_iso_fwd;
        Backward := Δ1_D_iso_bwd;
    |}).
    (* Round trip on Δ 1: compare coordinatewise; 'I_1 only contains ord0. *)
    - move=> [x pf]; apply: eq_Δ => i.
        rewrite /Rn_of_Δ !ffunE /=.
        by rewrite ord1.
    (* Round trip on 'D_1: same point; the witnesses agree by irrelevance. *)
    - move=> [x pf] /=.
        f_equal; exact: proof_irrelevance.
Defined.

(* Implementation at the bottom of the file. *)
(* Kock-Lawvere style interface for Δ n.
   [KL_Δ]: functions [Δ n -> R] correspond to pairs (constant term,
   n linear coefficients).
   [KL_Δ_back]: the backward direction yields the affine function
   d |-> c + sum_i a_i * d i. *)
Module Type Δ_properties_sig.
    Axiom KL_Δ : forall {n}, (Δ n -> R) <--> (R * ('I_n -> R)).
    Axiom KL_Δ_back : forall {n coeffs} {d : Δ n},
        Backward _ _ KL_Δ coeffs d = fst coeffs + \sum_i snd coeffs i * d i.
End Δ_properties_sig.













(* Implementation details that involve Weil algebras *)
Require Import weil_algebras.

(* Implementation of [nilpotent_properties_sig]. The Weil-algebra facts below
   are proved; the isomorphism and the exported [KL_D] / [KL_D_back] are
   still [Admitted]. *)
Module Export nilpotent_properties : nilpotent_properties_sig.

    (* Product on coordinate vectors indexed by 'I_n: coordinate k of the
       product sums x i * y j over all pairs (i, j) with i + j + 1 = k.
       Pairs with i + j + 1 >= n match no coordinate and are dropped.
       NOTE(review): this reads like multiplication in the ideal of
       R[ε]/(ε^(n+1)) with coordinate i standing for ε^(i+1) -- confirm. *)
    Definition nilpotent_weil_mul n
    (x : {ffun 'I_n -> R^o}) (y : {ffun 'I_n -> R^o}) : {ffun 'I_n -> R^o} :=
    [ffun k : 'I_n => \sum_(i < n) \sum_(j < n) 
        if (i + j + 1 == k)%N then x i * y j else 0
    ].

    (* Multiplication by a fixed left factor z is linear. Proof: work
       coordinatewise, push scaling and addition through both sums, then
       split on the index condition. *)
    Fact nilpotent_weil_mul_linear n z : linear (nilpotent_weil_mul n z).
    Proof.
        move=> a x y.
        rewrite /nilpotent_weil_mul {2}/*:%R {4}/+%R; simpl.
        rewrite /ffun_scale /ffun_add.
        apply: eq_ffun => k.
        do 3 rewrite ffunE.

        rewrite scaler_sumr -big_split; simpl.
        apply: eq_bigr => i _.
        rewrite scaler_sumr -big_split; simpl.
        apply: eq_bigr => j _.

        case: (i + j + 1 == k)%N.
        + rewrite scalerAr -mulrDr.
            rewrite {1}/*:%R {1}/+%R; simpl.
            by rewrite /ffun_scale /ffun_add ffunE ffunE.
        + by rewrite addr0 scaler0.
    Qed.

    (* Associativity, proved coordinatewise at index k in three steps. *)
    Fact nilpotent_weil_mul_assoc n : associative (nilpotent_weil_mul n).
    Proof.
        move=> x y z.
        apply: eq_ffun => k.

        (* Step 1: expand both sides into quadruple sums guarded by two
           index equations. *)
        suff: (\sum_(i < n)\sum_(j < n)\sum_(i' < n)\sum_(j' < n)
            if (i + j + 1 == k)%N && (i' + j' + 1 == j)%N then x i * y i' * z j' else 0)
            = (\sum_(i < n)\sum_(j < n)\sum_(i' < n)\sum_(j' < n)
            if (i + j + 1 == k)%N && (i' + j' + 1 == i)%N then x i' * y j' * z j else 0).
        {
            move=> eq.
            apply: Logic.eq_trans; first (apply: Logic.eq_trans; last apply: eq).
            - apply: eq_bigr => i _.
                apply: eq_bigr => j _.
                case: (i + j + 1 == k)%N.
                + rewrite ffunE mulr_sumr.
                    apply: eq_bigr => i' _.
                    rewrite mulr_sumr.
                    apply: eq_bigr => j' _.
                    case: (i' + j' + 1 == j)%N; simpl.
                    * by rewrite mulrA.
                    * by rewrite mulr0.
                + rewrite big1; first done.
                    move=> i' _.
                    by rewrite big1.
            - apply: eq_bigr => i _.
                apply: eq_bigr => j _.
                case: (i + j + 1 == k)%N.
                + rewrite ffunE mulr_suml.
                    apply: eq_bigr => i' _.
                    rewrite mulr_suml.
                    apply: eq_bigr => j' _.
                    case: (i' + j' + 1 == i)%N; simpl.
                    * done.
                    * by rewrite mul0r.
                + rewrite big1; first done.
                    move=> i' _.
                    by rewrite big1.
        }

        (* Helper: a sum over a predicate with at most one witness equals the
           term at the witness picked by [pick], or 0 if there is none. *)
        have: forall (P : 'I_n -> bool) F, (forall a b, P a -> P b -> a = b)
            -> (\sum_(i | P i) F i) = if [pick x | P x] is Some x then F x else 0.
        {
            (* NOTE(review): the extra leading binder [t] is presumably the
               type that [have] generalized from the unannotated [F]. *)
            move=> t P F eq.
            case pickP.
            - move=> i pf.
                apply: big_pred1.
                move=> j.
                apply/idP/eqP.
                + move=> pf2.
                    by apply: eq.
                + by move=>->.
            - move=> H.
                by apply: big_pred0.
        }
        move=> lemma.

        (* Step 2: use [lemma] to collapse the sum over the index that is
           forced to equal i' + j' + 1, leaving triple sums with a single
           equation on k. *)
        suff: (\sum_(i < n)\sum_(i' < n)\sum_(j' < n)
            if (i + (i' + j' + 1) + 1 == k)%N then x i * y i' * z j' else 0)
            = (\sum_(j < n)\sum_(i' < n)\sum_(j' < n)
            if ((i' + j' + 1) + j + 1 == k)%N then x i' * y j' * z j else 0).
        {
            move=> eq.
            apply: Logic.eq_trans; first (apply: Logic.eq_trans; last apply: eq).
            - apply: eq_bigr => i _.
                rewrite exchange_big; apply: eq_bigr => i' _.
                rewrite exchange_big; apply: eq_bigr => j' _.
                rewrite -big_mkcond.
                rewrite lemma.
                + case pickP.
                    * move=> j /andP [/eqP <- /eqP <-].
                        by case eqP.
                    (* No witness: the guard on k must then be false, since
                       i' + j' + 1 < n would itself be a witness. *)
                    * case: eqP; last done.
                        move=> eq2.
                        suff: (i' + j' + 1)%N < n.
                        {
                            move=> less nope.
                            suff: false by [].
                            rewrite -(nope (Ordinal less)) -eq2; simpl.
                            by apply /andP.
                        }
                        apply: leq_ltn_trans; last exact: ltn_ord k.
                        rewrite -eq2.
                        apply: leq_trans; [apply: leq_addl | apply: leq_addr].
                (* Uniqueness of the witness required by [lemma]. *)
                + move=> [a a_pf] [b b_pf] /andP [/eqP <- /eqP ->] /andP [_ /eqP eq2].
                    simpl in eq2.
                    move: eq2 a_pf => -> a_pf.
                    f_equal.
                    apply: proof_irrelevance.
            - apply: Logic.eq_sym.
                rewrite exchange_big; apply: eq_bigr => j _.
                rewrite exchange_big; apply: eq_bigr => i' _.
                rewrite exchange_big; apply: eq_bigr => j' _.
                rewrite -big_mkcond.
                rewrite lemma.
                + case pickP.
                    * move=> i /andP [/eqP <- /eqP <-].
                        by case eqP.
                    * case: eqP; last done.
                        move=> eq2.
                        suff: (i' + j' + 1)%N < n.
                        {
                            move=> less nope.
                            suff: false by [].
                            rewrite -(nope (Ordinal less)) -eq2; simpl.
                            by apply /andP.
                        }
                        apply: leq_ltn_trans; last exact: ltn_ord k.
                        rewrite -eq2.
                        apply: leq_trans; [apply: leq_addr | apply: leq_addr].
                + move=> [a a_pf] [b b_pf] /andP [/eqP <- /eqP ->] /andP [_ /eqP eq2].
                    simpl in eq2.
                    move: eq2 a_pf => -> a_pf.
                    f_equal.
                    apply: proof_irrelevance.
        }

        (* Step 3: reorder the summations so the bodies line up, then show
           the two index expressions are equal using addnA / addnC. *)
        apply: Logic.eq_sym.
        rewrite exchange_big; apply: eq_bigr => i _.
        rewrite exchange_big; apply: eq_bigr => i' _.
        apply: eq_bigr => j' _.

        suff: (i + i' + 1 + j')%N = (i + (i' + j' + 1))%N by move=>->.
        do 3 rewrite -addnA; f_equal; f_equal.
        apply: addnC.
    Qed.

    (* Commutativity: exchange the two sums, then commute the index sum
       and the factors. *)
    Fact nilpotent_weil_mul_comm n : commutative (nilpotent_weil_mul n).
    Proof.
        move=> x y.
        apply: eq_ffun => k.
        rewrite exchange_big.
        apply: eq_bigr => i _.
        apply: eq_bigr => j _.
        move: (addnC j i) => ->.
        move: (mulrC (x j) (y i)) => ->.
        done.
    Qed.

    (* A product of the n vectors [xs i], folded onto the seed [x], is 0. *)
    Fact nilpotent_weil_mul_nilpotent n x (xs : 'I_n -> _) 
        : \big[nilpotent_weil_mul n/x]_i xs i = 0.
    Proof.
        move: xs.

        (* Strengthened statement: coordinate i of a product of m factors
           is 0 whenever i < m; take m = n to conclude. *)
        suff: forall m xs (mn : m <= n) (i : 'I_n), i < m -> 
            (\big[nilpotent_weil_mul n/x]_(j : 'I_m) xs j) i = 0.
        {
            move=> H xs.
            rewrite -ffunP => i.
            rewrite (H n xs (leqnn n) i (ltn_ord i)).
            by rewrite /0; simpl; rewrite /ffun_zero ffunE.
        }

        (* Induction on m. Peel off the first factor; any contributing
           term pairs coordinate j of the remaining product with j < m,
           which is 0 by the induction hypothesis. *)
        elim=> [|m IH] xs mn k pf.
        - done.
        - rewrite big_ord_recl ffunE.
            rewrite big1; first done.
            move=> i _.
            rewrite big1; first done.
            move=> j _.
            case eqP; last done.
            move=> eq.

            suff: (\big[nilpotent_weil_mul n/x]_(i0 < m) 
                xs (lift ord0 i0)) j = 0.
            {
                move=>->.
                apply: mulr0.
            }

            apply: IH.
            - by apply: ltnW.
            - rewrite -ltnS.
                apply: leq_ltn_trans; last apply: pf.
                move: eq => <-.
                rewrite addn1.
                apply leq_addl.
    Qed.

    (* Weil algebra data with coordinates 'I_n and product
       [nilpotent_weil_mul n]. *)
    Definition nilpotent_weil (n : nat) : WeilData R := {|
        Domain := [finType of 'I_n];
        Mul := nilpotent_weil_mul n;
        Lin := nilpotent_weil_mul_linear n;
        Assoc := nilpotent_weil_mul_assoc n;
        Comm := nilpotent_weil_mul_comm n;
        Nilpotent := ex_intro _ _ (nilpotent_weil_mul_nilpotent n);
    |}.

    (* Intended isomorphism between 'D_n and Spec of the Weil algebra above.
       Unfinished: the [refine] leaves goals open and the definition is
       [Admitted].
       NOTE(review): Forward sends x to the constant vector [ffun i => x],
       and Backward sums all coordinates. If coordinate i stands for
       ε^(i+1), a morphism would presumably send it to x ^+ i.+1, and
       Backward would read coordinate 0 -- confirm against reifyIso. *)
    Definition nilpotent_weil_iso {n} : 'D_n <--> Spec R (Weil (nilpotent_weil n)).
        apply: iso_trans; last apply: reifyIso.
        refine({|
            Forward '(exist x pf) := exist _ [ffun i => x] _;
            Backward '(exist f morph) := exist _ (\sum_i f i) _;
        |}).
    Admitted.

    (* Not yet proved; see [nilpotent_properties_sig] for the statements. *)
    Definition KL_D {n} : (nilpotent n -> R) <--> ('I_n -> R).
    Admitted.
    Theorem KL_D_back {n coeffs} {d : nilpotent n}
        : Backward _ _ KL_D coeffs d = polynomial coeffs d.
    Admitted.
End nilpotent_properties.

(* Implementation of [Δ_properties_sig]; [KL_Δ] and [KL_Δ_back] are still
   [Admitted]. *)
Module Export Δ_properties : Δ_properties_sig.
    (* Weil algebra data for Δ n: coordinates 'I_n and the zero product, so
       associativity and commutativity hold by reflexivity. *)
    Definition Δ_weil (n : nat) : WeilData R.
    Proof.
        refine ({|
            Domain := [finType of 'I_n];
            Mul := fun _ _ => 0;
            (* Linearity: the term proves 0 = a *: 0 + 0 from addr0/scaler0. *)
            Lin := fun _ _ _ _ => Logic.eq_sym (Logic.eq_trans (addr0 _) (scaler0 _ _));
            Assoc := fun _ _ _ => Logic.eq_refl;
            Comm := fun _ _ => Logic.eq_refl;
            (* Nilpotency witness indexed by 'I_1; the proof is left to the
               script below. *)
            Nilpotent := ex_intro _ [finType of 'I_1] _
        |}).
        (* Presumably the goal is a big product over 'I_1 equal to 0; one
           [big_ord_recl] step exposes [Mul], which returns 0. *)
        move=> x xs.
        apply: big_ord_recl.
    Defined.

    (* Not yet proved; see [Δ_properties_sig] for the statements. *)
    Definition KL_Δ {n} : (Δ n -> R) <--> (R * ('I_n -> R)).
    Admitted.
    Theorem KL_Δ_back {n coeffs} {d : Δ n}
        : Backward _ _ KL_Δ coeffs d = fst coeffs + \sum_i snd coeffs i * d i.
    Admitted.
End Δ_properties.