Require Import List Arith Lia Morphisms Setoid.
From Undecidability.HOU Require Import calculus.calculus.
From Undecidability.HOU Require Import
        unification.higher_order_unification unification.nth_order_unification
        concon.conservativity_constants concon.conservativity.
Import ListNotations.

(* Make a plain Proof. behave like Proof using Type, so section variables
   are only added to a lemma when its statement needs them. *)
Set Default Proof Using "Type".

(* Section Retracts: relates unification over a constant signature X to
   unification over a signature Y, given a retract RE of X into Y and the
   type-agreement hypothesis consts_agree.
   NOTE(review): non-ASCII tokens (quantifiers, arrows, =>, binder names such
   as sigma/Gamma/Delta written in Greek) appear to have been stripped from
   this section, so the text below does not parse as-is. Restore it from the
   upstream source before building. *)
Section Retracts.
    Variable (X Y: Const).
    (* RE provides the embedding I of X-constants into Y-constants (used as
       I x below) and the partial inverse tight RE (used in re). *)
    Variable (RE: retract X Y).
    (* The embedding preserves constant types.
       NOTE(review): the leading quantifier appears stripped; presumably
       forall x: X, ctype X x = ctype Y (I x). *)
    Hypothesis consts_agree: x: X, ctype X x = ctype Y (I x).

    (* inj c: the X-constant c viewed as the Y-constant I c. *)
    Let inj (c: X) := const (I c).
    (* re f d: maps a Y-constant back to an X-term. If tight RE d yields
       Some x, the result is the constant x. Otherwise it is
       inhab X (f d) (ctype Y d), built from the index f d and the type of d.
       NOTE(review): presumably inhab builds a term of that type from the
       variable index f d, and f has type Y -> nat. Confirm against the
       inhab definition. *)
    Let re (f: Y ) (d: Y) :=
      match tight RE d with
      | Some x const x
      | None inhab X (f d) (ctype Y d)
      end.

    (* Renaming does not change the image of inj. inj c is a constant with no
       variables, so the equation holds by computation after extensionality.
       NOTE(review): the argument of ren appears stripped; presumably ren delta. *)
    Lemma inj_ren delta: inj >> ren = inj.
    Proof.
      fext; reflexivity.
    Qed.


    (* Post-composing re f with a renaming equals re applied to f composed with
       that renaming. The constant branch is trivial, and the inhab branch is
       closed by inhab_ren.
       NOTE(review): the renaming arguments appear stripped from the statement. *)
    Lemma re_ren f delta: re f >> ren = re (f >> ).
    Proof.
      fext; intros x; unfold funcomp, re.
      destruct tight; eauto. now rewrite inhab_ren.
    Qed.


    (* Translating constants with inj commutes with instantiation.
       Applying subst_consts inj to (sigma applied to s) equals applying
       (sigma >> subst_consts inj) to the translated term.
       Proof by induction on s. The binder case uses inj_ren and
       ren_subst_consts_commute.
       NOTE(review): the instantiation operator and sigma appear stripped from
       the statement and from the rewrite arguments below. *)
    Lemma subst_consts_inject_forward sigma s:
      subst_consts inj ( s) =
      ( >> subst_consts inj) (subst_consts inj s).
    Proof.
      induction s in |-*; cbn in *; intuition.
      - f_equal. rewrite inj_ren.
        rewrite IHs. f_equal.
        fext; intros []; cbn; eauto.
        rewrite inj_ren with ( := shift) at 1.
        unfold funcomp at 2. now rewrite ren_subst_consts_commute.
      - now rewrite , .
    Qed.


    (* Round trip: translating an X-term forward with inj, instantiating it,
       and translating back with re f is the same as instantiating the
       original term with sigma >> subst_consts (re f).
       The constant case reduces via retr, because tight RE (I c) recovers c.
       The binder case uses subst_consts_up, inj_ren and re_ren.
       NOTE(review): sigma and the instantiation operator appear stripped from
       the statement and from the tactic arguments below. *)
    Lemma subst_consts_inject_backwards sigma s f:
      subst_consts (re f) ( subst_consts inj s) =
      ( >> subst_consts (re f)) s.
    Proof.
      induction s in , f |-*; cbn.
      - reflexivity.
      - unfold funcomp. unfold re. unfold tight. rewrite retr.
        destruct (I c == I c); intuition.
      - f_equal.
        rewrite subst_consts_up, inj_ren, re_ren; eauto.
      - rewrite , ; eauto.
    Qed.


    (* A well-typed order-n substitution stays well typed after its constants
       are translated with inj. Each translated constant keeps its type by
       consts_agree.
       NOTE(review): the context and substitution names, and the implication
       arrow, appear stripped from the statement. Presumably it reads
       Delta |-(n) sigma : Gamma -> Delta |-(n) sigma >> subst_consts inj : Gamma. *)
    Lemma inj_typing n sigma Gamma Delta:
        ⊩(n) :
        ⊩(n) >> subst_consts inj : .
    Proof using consts_agree.
      intros ????. eapply ordertyping_preservation_consts; eauto.
      intros ??; unfold inj.
      rewrite consts_agree.
      econstructor; rewrite consts_agree.
      eapply typing_constants; eauto.
    Qed.


    (* Typing for the backward translation re f. Assumptions:
       - 1 <= n;
       - sigma is well typed;
       - for every variable x in the domain and every constant c occurring in
         sigma x, the context entry at index f c is target (ctype Y c).
       Conclusion: sigma >> subst_consts (re f) is well typed.
       Constants recognised by tight RE keep their types via consts_agree.
       All other constants become inhab terms, typed by inhab_typing.
       NOTE(review): relation symbols and binder names appear stripped;
       the exact form of the statement needs to be confirmed upstream. *)
    Lemma re_typing n Delta f sigma Gamma:
      1 n
       ⊩(n) :
      ( x c, x dom c consts ( x) nth (f c) = Some (target (ctype Y c)))
       ⊩(n) >> subst_consts (re f) : .
    Proof using consts_agree.
      intros L T Sub x A H; unfold funcomp.
      eapply ordertyping_preservation_consts; eauto.
      intros y H'.
      unfold re. destruct (tight RE y) eqn: EQr.
      - eapply tight_is_tight in EQr. subst. rewrite consts_agree.
        econstructor. rewrite consts_agree. eapply typing_constants; eauto.
      - eapply ordertyping_monotone, inhab_typing; domin H; eauto.
    Qed.


    (* Turns an order-n unification instance over X into one over Y.
       Both terms are translated with inj; the type A0 and the context Gamma0
       are reused unchanged.
       The two obligations show that the translated terms are still well typed
       (via consts_agree). *)
    Program Instance unification_retract {n} (I: orduni n X) : orduni n Y :=
      {
        s₀ := subst_consts inj s₀;
        t₀ := subst_consts inj t₀;
        A₀ := A₀;
        Gamma₀ := Gamma₀;
      }.
    Next Obligation.
      eapply ordertyping_preservation_consts. eapply H1₀.
      intros x . unfold inj. rewrite consts_agree.
      econstructor. rewrite consts_agree.
      eapply typing_constants. eapply H1₀. eauto.
    Qed.
    Next Obligation.
      eapply ordertyping_preservation_consts. eapply H2₀.
      intros x . unfold inj. rewrite consts_agree.
      econstructor. rewrite consts_agree.
      eapply typing_constants. eapply H2₀. eauto.
    Qed.

    (* Forward direction: a solution of I over X gives a solution of the
       translated instance, namely the same substitution followed by
       subst_consts inj. Typing comes from inj_typing, and the equation from
       subst_consts_inject_forward.
       NOTE(review): the implication arrow and the names in the intro pattern
       appear stripped. *)
    Lemma unification_retract_forward n (I: orduni n X):
      OU n X I OU n Y (unification_retract I).
    Proof.
      intros ( & & T & EQ).
      exists . exists ( >> subst_consts inj). split.
      - now eapply inj_typing.
      - unfold s₀, t₀; cbn.
        rewrite !subst_consts_inject_forward.
        now rewrite EQ.
    Qed.


    (* Backward direction, for 1 <= n: a Y-solution of the translated instance
       gives an X-solution of I.
       - C lists the constants occurring in the solution at the variables
         nats (length Gamma0).
       - f sends the constant at position x of C to the index (length of the
         solution context) + x, and any other constant to 0.
       - The solution context is extended with target' of the types of C, so
         the inhab terms chosen by re f are well typed.
       Typing comes from re_typing and weakening_ordertyping_app. The equation
       comes from subst_consts_inject_backwards.
       NOTE(review): the solution context and substitution names (and the
       append operator) appear stripped from the pose/exists lines below. *)
    Lemma unification_retract_backward n (I: orduni n X):
      1 n OU n Y (unification_retract I) OU n X I.
    Proof.
      intros Leq ( & & T & EQ).
      pose (C := Consts (map (nats (length Gamma₀)))).
      pose (f y := match find y C with
                  | Some x length + x
                  | None 0
                  end).
      exists ( target' (map (ctype Y) C)).
      exists ( >> subst_consts (re f)). split.
      - eapply re_typing; eauto.
        intros ???. eapply weakening_ordertyping_app; eauto.
        intros x y . eapply Consts_consts with (S := map (nats (| Gamma₀ |))) in ;
                            eauto using in_map, lt_nats.
        unfold f, C. eapply find_in in as [? ]; rewrite .
        rewrite nth_error_app2; simplify; eauto.
        unfold target'; erewrite map_map, map_nth_error; simplify; eauto.
        now eapply find_Some.
      - unfold s₀, t₀ in EQ; cbn in EQ.
        now rewrite !subst_consts_inject_backwards, EQ.
    Qed.


    (* Main result of the section: for 1 <= n, order-n unification over X
       reduces to order-n unification over Y. The reduction function is
       unification_retract; correctness is the forward and backward lemmas.
       NOTE(review): the arrow and the reduction symbol between the two
       problems appear stripped from the statement. *)
    Lemma unification_constants_monotone n:
      1 n OU n X OU n Y.
    Proof using re inj consts_agree RE.
      intros H; exists unification_retract.
      intros I; split;
        eauto using unification_retract_forward, unification_retract_backward.
    Qed.


End Retracts.

(* Section RemoveConstants: here Y is a retract of X (the direction is the
   reverse of the section above). Unification over X is reduced to unification
   over Y by abstracting every X-constant that is not the image I y of a
   Y-constant.
   NOTE(review): non-ASCII tokens (quantifiers, arrows, =>, the nat type,
   Greek binder names, the instantiation operator) appear stripped from this
   section; restore them from upstream before building. *)
Section RemoveConstants.

  Variable (X Y: Const) (RE: retract Y X).

  (* The embedding I of Y-constants into X preserves constant types. *)
  Hypothesis (consts_agree: y, ctype Y y = ctype X (I y)).

  (* R' x is Some y when x is recognised as I y (see tight_is_tight uses
     below), and None otherwise. *)
  Let R' x := tight RE x.

  (* Encodes an X-constant as a Y-term.
     - Constants from Y become the matching Y-constant.
     - Other constants become the variable at their position in A.
     - Constants missing from A fall back to var 0; the lemmas below always
       require such constants to be in C. *)
  Let enc_const (A: list X) (x: X): exp Y :=
    match R' x with
    | Some y const y
    | None
      match find x A with
      | Some n var n
      | None var 0
      end
    end.

  (* Variable y is shifted past the length A new binders and applied to the
     variables listed by nats (length A). *)
  Let enc_var (A: list X) (y: ) : exp X :=
    AppR (var (y + length A)) (map var (nats (length A))).


  (* Encodes an X-term: substitute variables with enc_var C, translate
     constants with enc_const C, and close under length C lambdas. *)
  Let enc_term (C: list X) (s: exp X): exp Y :=
    Lambda (length C) (subst_consts (enc_const C) (enc_var C s)).

  (* Prefixes A with the types of the constants in C, as a reversed argument
     list. *)
  Let enc_type (C: list X) (A: type): type :=
    Arr (rev (map (ctype X) C)) A.

  (* Applies enc_type C to every entry of a context. *)
  Let enc_ctx (C: list X) (Gamma: ctx): ctx :=
    map (enc_type C) .

  (* Applies enc_term C pointwise to a substitution. *)
  Let enc_subst (C: list X) (sigma: fin exp X) (x: ) :=
    enc_term C ( x).

  (* A Y-constant viewed as an X-term. *)
  Let ι (y: Y) : exp X := const (I y).

  (* Decodes a Y-term: map its constants into X via ι, then apply it to the
     abstracted constants C. *)
  Let inv_term C (s: exp Y) :=
    AppR (subst_consts ι s) (map const C).

  (* Applies inv_term C pointwise to a substitution. *)
  Let inv_subst C (sigma: fin exp Y) (x: ) :=
    inv_term C ( x).
    Set Default Proof Using "Type".

  (* Lemmas about the encoding for a fixed list C of X-constants to abstract,
     assuming the order of their types (ord' of their ctype list) is below n. *)
  Section EncodingLemmas.
    Variable (C: list X) (n: ).
    Hypothesis (O: ord' (map (ctype X) C) < n).

    (* Typing preservation for enc_term. Suppose s has type A at order n, and
       every constant x of s with R' x = None is in C. Then enc_term C s has
       type enc_type C A in enc_ctx C.
       Case split on constants:
       - constants with R' x = Some _ keep their type by consts_agree;
       - constants found in C become variables whose order is bounded by O.
       NOTE(review): the context name and the arrows and membership symbols
       appear stripped from the statement and intro patterns. *)
    Lemma remove_constants_ordertyping Gamma s A:
       ⊢(n) s : A
      ( x, x consts s R' x = None x C)
      enc_ctx C ⊢(n) enc_term C s : enc_type C A.
    Proof using consts_agree O.
      intros T H. eapply Lambda_ordertyping; simplify; eauto.
      eapply ordertyping_preservation_consts.
      eapply ordertyping_weak_preservation_under_substitution; eauto.
      - intros y B . unfold enc_var.
        eapply AppR_ordertyping.
        + eapply map_var_typing with (L := map (ctype X) C).
          * intros x; rewrite dom_lt_iff; simplify.
            intros ? % nats_lt; .
          * rewrite select_nats.
            rewrite firstn_app; simplify.
            rewrite firstn_all; cbn; now simplify.
          * eauto.
        + econstructor; simplify; intuition.
          eapply vars_ordertyping_nth with (n := n) ( := )
            in ; eauto.
          unfold enc_ctx;
            erewrite nth_error_app2, map_nth_error; simplify; now eauto.
      - intros x H'. unfold enc_const.
        eapply consts_subst_in in H' as [].
        destruct (R' x) eqn: EQ.
        + eapply tight_is_tight in EQ as EQ'; subst x.
          rewrite consts_agree. econstructor.
          rewrite consts_agree. eapply typing_constants; eauto.
        + destruct find eqn: .
          * eapply find_Some in .
            econstructor. rewrite O.
            now eapply ord'_in, in_map, H.
            rewrite nth_error_app1; simplify;
              eauto using nth_error_Some_lt.
            erewrite map_nth_error; now eauto.
          * exfalso.
            eapply find_not_in in ; intuition.
        + unfold enc_var in . destruct . intuition.
          rewrite consts_AppR in . simplify in .
          unfold Consts in ; eapply in_flat_map in as [].
          intuition. mapinj. cbn in ; intuition.
    Qed.


  (* Typing for the decoding. A Y-term of type enc_type C A becomes, after
     inv_term C (translating constants with consts_agree and applying to the
     constants of C), an X-term of type A. O bounds the order of the applied
     constants.
     NOTE(review): the context name and arrow appear stripped. *)
  Lemma inv_term_typing Gamma s A:
     ⊢(n) s : enc_type C A
     ⊢(n) inv_term C s : A.
  Proof using consts_agree O.
    intros H; unfold inv_term.
    eapply AppR_ordertyping with (L := map (ctype X) C).
    eapply const_ordertyping_list. rewrite O; eauto.
    eapply ordertyping_preservation_consts; [eauto|].
    intros y ?; rewrite consts_agree.
    econstructor. rewrite consts_agree.
    eapply typing_constants; eauto.
  Qed.


  (* Pointwise lift of remove_constants_ordertyping to substitutions. If sigma
     is well typed, and every constant c in sigma x with R' c = None is in C,
     then enc_subst C sigma is well typed between the encoded contexts.
     NOTE(review): context and substitution names appear stripped. *)
  Lemma remove_constants_ordertypingSubst Delta sigma Gamma :
      ⊩(n) :
    ( x c, c consts ( x) R' c = None c C)
    enc_ctx C ⊩(n) enc_subst C : enc_ctx C .
  Proof using consts_agree O.
    intros ?????. unfold enc_ctx in . rewrite nth_error_map_option in .
    destruct nth eqn: EQ; try discriminate; injection as .
    eapply remove_constants_ordertyping; eauto.
  Qed.


  (* Pointwise lift of inv_term_typing. A substitution typed against
     enc_ctx C decodes, via inv_subst C, to a substitution typed against the
     original context.
     NOTE(review): context and substitution names appear stripped. *)
  Lemma inv_subst_typing Delta sigma Gamma:
      ⊩(n) : enc_ctx C
      ⊩(n) inv_subst C : .
  Proof using consts_agree O.
    intros ????. eapply inv_term_typing, H.
    unfold enc_ctx; erewrite map_nth_error; eauto.
  Qed.


  (* The Proof-using default is switched off for the two Proper instances
     below and switched back on afterwards. *)
  Unset Default Proof Using.

  (* enc_term C respects equiv (@step X), presumably equivalence modulo
     reduction, so setoid rewriting works underneath it. *)
  Global Instance enc_proper:
    Proper (equiv (@step X) ++> equiv (@step Y)) (enc_term C).
  Proof.
    intros ?? H; unfold enc_term; now rewrite H.
  Qed.


  (* inv_term C respects equiv (@step Y) in the same sense. *)
  Global Instance inv_proper:
    Proper (equiv (@step Y) ++> equiv (@step X)) (inv_term C).
  Proof.
    intros ?? H; unfold inv_term; now rewrite H.
  Qed.


  Set Default Proof Using "Type".


  (* Congruence lemma used by the reduction proofs below. Suppose:
     - on the variables of s, the two substitutions (each followed by its
       constant translation) are related by multi-step reduction;
       and
     - on the constants of s, the two constant maps are related by
       multi-step reduction.
     Then translating the two instantiations of s are related by >*.
     Proof by induction on s; the lambda case shifts both substitutions
     via up and ren shift.
     NOTE(review): the binder names (sigma, tau, theta, zeta, kappa) and the
     arrows appear stripped from the statement and tactic arguments.
     NOTE(review): the Proof using clause lists section names (iota,
     inv_term, enc_term, ...) that the visible proof does not mention; it may
     be possible to trim it, but confirm with the compiler. *)
  Lemma subst_consts_subst Z (s: exp X) sigma tau theta zeta (kappa: X exp Z):
    ( x, x vars s subst_consts ( x) >* subst_consts ( x))
    ( x, x consts s x >* x)
     subst_consts ( s) >* subst_consts ( s).
  Proof using ι n inv_term inv_subst enc_term enc_subst enc_const Y RE R' C.
    induction s in , , , , |-*.
    - cbn; intros; eapply H; now econstructor.
    - cbn; intros; eapply ; eauto.
    - cbn -[vars]; intros.
      rewrite IHs with ( := >> ren shift) ( := up ); eauto.
      + intros []; cbn; eauto.
        unfold funcomp at 2.
        rewrite ren_subst_consts_commute.
        unfold funcomp at 2. rewrite ren_subst_consts_commute.
        unfold up. asimpl.
        erewrite compSubstSubst_exp; try reflexivity.
        intros; eapply subst_steps, H. eauto.
      + intros x. unfold funcomp.
        asimpl. erewrite compSubstSubst_exp; try reflexivity.
        intros; eapply subst_steps, ; eauto.
    - intros; cbn; rewrite , ; try reflexivity.
      1, 3: intros; eapply H; eauto.
      all: intros; eapply ; cbn; simplify; intuition.
  Qed.


  (* Encoding commutes with instantiation up to reduction. Instantiating
     enc_term C s with enc_subst C tau reduces to enc_term C applied to
     (tau instantiated into s), given the C-membership side condition for the
     constants of both tau and s.
     The proof uses subst_consts_subst twice and beta-reduces the
     AppR/Lambda redexes (AppR_Lambda').
     NOTE(review): tau and the instantiation operator appear stripped from
     the statement.
     NOTE(review): le_plus_minus_r may be deprecated in recent Coq versions;
     check this against the Coq version in use. *)
  Lemma enc_subst_term_reduce tau s:
    ( x c, c consts ( x) R' c = None c C)
    ( x, x consts s R' x = None x C)
    enc_subst C enc_term C s >* enc_term C ( s).
  Proof using n.
    intros ; unfold enc_term. asimpl. eapply Lambda_steps_proper.
    rewrite subst_consts_subst; eauto.
    - intros x ?. unfold funcomp at 1.
      unfold enc_var.
      rewrite !subst_consts_AppR, AppR_subst; cbn.
      rewrite it_up_ge; simplify; eauto.
      rewrite map_id_list; cbn.
      rewrite map_map; cbn. change ( x @var Y x) with (@var Y).
      unfold enc_subst at 1, enc_term. rewrite Lambda_ren.
      rewrite AppR_Lambda'; simplify; eauto.
      asimpl. rewrite subst_consts_subst; eauto.
      + intros ? ?; unfold funcomp.
        unfold enc_var. rewrite idSubst_exp; eauto.
        intros y; cbn.
        destruct (le_lt_dec (length C) y).
        rewrite it_up_ren_ge, le_plus_minus_r, sapp_ge_in; simplify; eauto.
        erewrite it_up_ren_lt, nth_error_sapp; eauto.
        erewrite map_nth_error; eauto using nth_nats.
      + unfold enc_const; intros c; destruct (R' c) eqn: ?; cbn; eauto.
        intros [m H'] % % find_in; eauto; rewrite H'.
        eapply find_Some, nth_error_Some_lt in H'.
        cbn; unfold funcomp; erewrite it_up_ren_lt, nth_error_sapp; eauto.
        erewrite map_nth_error; eauto using nth_nats.
      + intros ? ?; mapinj; mapinj; cbn; rewrite it_up_lt; eauto using nats_lt.
    - unfold enc_const; intros c; destruct (R' c) eqn: ?; cbn; eauto.
      intros [m H'] % % find_in; eauto; rewrite H'.
      eapply find_Some, nth_error_Some_lt in H'.
      cbn; erewrite it_up_lt; eauto.
  Qed.


  (* Decoding an instantiated encoded term: inv_term C applied to (sigma
     instantiated into enc_term C s) reduces to s instantiated with
     inv_subst C sigma, provided every non-I constant of s is in C.
     The auxiliary substitution posed below encodes sigma x shifted by
     length C and applied to the abstracted variables. tight_retr shows that
     translated Y-constants are left unchanged by enc_const.
     NOTE(review): sigma, the instantiation operator and the name of the
     posed substitution appear stripped.
     NOTE(review): the proof refers to the auto-generated hypothesis name
     Heqo from destruct ... eqn: ?; this may break if Coq changes its naming
     scheme. *)
  Lemma enc_term_app sigma s:
    ( x, x consts s R' x = None x C)
    inv_term C ( enc_term C s) >* inv_subst C s.
  Proof using n.
    intros H. unfold enc_term, inv_term.
    asimpl. rewrite subst_consts_Lambda.
    rewrite AppR_Lambda'; simplify; eauto.
    replace (ι >> _) with ι by (fext; intros ?; reflexivity).
    pose ( x := AppR
              (subst_consts ι (ren (plus (length C)) ( x)))
              (map var (nats (length C)))).
    erewrite subst_consts_subst with ( := enc_const C) ( := ).
    - rewrite subst_consts_comp.
      rewrite subst_consts_subst with ( := const) ( := inv_subst C ) .
      rewrite subst_consts_ident; eauto.
      + intros x V. unfold , inv_subst.
        rewrite subst_consts_AppR, subst_consts_comp.
        rewrite map_id_list.
        2: intros ??; mapinj; reflexivity.
        rewrite subst_consts_ident; eauto.
        replace (ι >> _) with (ι >> ren (plus (length C))).
        2: fext; intros c; unfold funcomp, enc_const, R';
          cbn; now rewrite tight_retr.
        rewrite ren_subst_consts_commute. unfold inv_term. asimpl.
        eapply refl_star. f_equal.
        * eapply idSubst_exp. intros y; unfold funcomp.
          erewrite sapp_ge_in; simplify; eauto.
        * clear V. eapply list_pointwise_eq.
          intros m; rewrite !nth_error_map_option.
          destruct (le_lt_dec (length C) m) as [|].
          -- edestruct nth_error_None as [_ ].
             edestruct nth_error_None as [_ ].
             all: cbn; simplify; eauto.
          -- rewrite nth_nats; eauto; cbn.
             destruct (nth_error_lt_Some _ m C) as [c ]; eauto.
             rewrite ; cbn. erewrite nth_error_sapp; eauto.
             erewrite map_nth_error; eauto.
      + intros ??. unfold funcomp.
        unfold enc_const.
        destruct (R' x) eqn: ?. cbn.
        eapply tight_is_tight in Heqo; now subst.
        eapply H in ; eauto.
        eapply find_in in as [m ]; rewrite .
        cbn. erewrite nth_error_sapp; eauto.
        erewrite map_nth_error; eauto.
        eapply find_Some; eauto.
    - intros; unfold enc_var, .
      rewrite subst_consts_AppR; cbn.
      rewrite AppR_subst; cbn; rewrite it_up_ge; eauto; simplify.
      rewrite subst_consts_AppR, subst_consts_comp, subst_consts_ident.
      2: intros ?; unfold funcomp, enc_const, R'; cbn; now rewrite tight_retr.
      eapply refl_star. f_equal.
      rewrite map_id_list; eauto.
      intros ??; mapinj; mapinj; cbn.
      rewrite it_up_lt; eauto using nats_lt.
    - intros. unfold enc_const. destruct (R' x) eqn: ?. cbn.
      eapply tight_is_tight in Heqo; now subst.
      eapply H in ; eauto.
      eapply find_in in as [m ]; rewrite .
       eapply find_Some, nth_error_Some_lt in .
      cbn; erewrite it_up_lt; eauto.
  Qed.


  (* Special case of enc_term_app with the identity substitution var: decoding
     an encoded term reduces to s with every variable x replaced by
     AppR (var x) (map const C).
     NOTE(review): the lambda arrow and the instantiation operator appear
     stripped from the statement. *)
  Lemma enc_inv_motivation s:
    ( x, x consts s R' x = None x C)
    inv_term C (enc_term C s) >* ( x AppR (var x) (map const C)) s.
  Proof using n.
    intros H. replace (enc_term C s) with (var enc_term C s) by now asimpl.
    rewrite enc_term_app; eauto.
  Qed.



  End EncodingLemmas.

  (* The constants of the instance terms s0 and t0 for which R' x = None,
     i.e. the constants that the encoding has to abstract.
     NOTE(review): the lambda arrow appears stripped from the filter predicate. *)
  Definition iConsts {n} (I: orduni n X) :=
    filter ( x if R' x == None then true else false)
           (Consts [s₀; t₀]).


  (* Encodes an order-n X-instance as an order-n Y-instance. Both terms, the
     type and the context are encoded with respect to iConsts I. The
     hypothesis H bounds the order of the abstracted constants' types.
     Each obligation applies remove_constants_ordertyping; filter_In shows
     that the relevant constants are in iConsts I. *)
  Program Instance remove_constants n (I: orduni n X)
          (H: ord' (map (ctype X) (iConsts I)) < n) : orduni n Y :=
    {
      s₀ := enc_term (iConsts I) s₀;
      t₀ := enc_term (iConsts I) t₀;
      A₀ := enc_type (iConsts I) A₀;
      Gamma₀ := enc_ctx (iConsts I) Gamma₀;
    }.
  Next Obligation.
    eapply remove_constants_ordertyping; eauto using H1₀.
    cbn; simplify; intuition.
    eapply filter_In; destruct eq_dec; intuition.
  Qed.
  Next Obligation.
    eapply remove_constants_ordertyping; eauto using H2₀.
    cbn; simplify; intuition.
    eapply filter_In; destruct eq_dec; intuition.
  Qed.


  (* Forward direction: a solution of I gives a solution of the encoded
     instance.
     1. downcast_constants_ordered replaces the solution with one whose
        constants all come from s and t (fact Cs).
     2. That solution is encoded with enc_subst C, where C is the same filter
        as in iConsts.
     3. Typing comes from remove_constants_ordertypingSubst, and the equation
        from enc_subst_term_reduce.
     NOTE(review): the arrow, the inequality in assert (1 n), the closing
     tactic after by, and several binder names appear stripped. *)
  Lemma remove_constants_forward n (I: orduni n X)
        (H: ord' (map (ctype X) (iConsts I)) < n):
    OU n X I OU n Y (remove_constants n I H).
  Proof.
    assert (1 n) as L by .
    destruct I as [ s t A ]; intros ( & & T' & E'); cbn in *.
    eapply downcast_constants_ordered in T'
      as ( & & T & E & _ & Cs); eauto; clear E'.
    pose (C := filter ( x if R' x == None then true else false)
          (Consts [s; t])).
    exists (enc_ctx C ( )).
    exists (enc_subst C ). split.
    - eapply remove_constants_ordertypingSubst; eauto.
      intros ? ? ? % Cs. intros ?; cbn; eapply filter_In.
      intuition. destruct eq_dec; intuition.
    - cbn [s₀ t₀ remove_constants iConsts Consts flat_map];
        simplify; unfold C.
      cbn in Cs; simplify in Cs.
      rewrite !enc_subst_term_reduce; eauto; intuition.
      now rewrite E.
      all: eapply filter_In; destruct eq_dec; cbn; intuition.
      all: rewrite app_nil_r; eauto.
  Qed.


  (* Backward direction: a solution of the encoded instance decodes, via
     inv_subst (iConsts I), to a solution of I. Typing comes from
     inv_subst_typing, and the equation from enc_term_app.
     NOTE(review): the arrow and the binder names in the intro pattern and
     exists appear stripped. *)
  Lemma remove_constants_backward n (I: orduni n X)
        (H: ord' (map (ctype X) (iConsts I)) < n):
    OU n Y (remove_constants n I H) OU n X I.
  Proof.
    pose (C := iConsts I).
    destruct I as [ s t A ]; intros ( & & T & EQ).
    exists . exists (inv_subst C ). split; eauto using inv_subst_typing.
    rewrite !enc_term_app. cbn [s₀ t₀ remove_constants] in EQ.
    unfold C; now rewrite EQ. all: eauto.
    all: intros; eapply filter_In; cbn; intuition; destruct eq_dec; intuition.
  Qed.


  (* Main result of the section. Assume 1 <= n and that every X-constant with
     tight RE x = None has type order below n. Then order-n unification over
     X reduces to order-n unification over Y.
     The auxiliary fact O establishes the order bound required by
     remove_constants for every instance, using ord'_elements and filter_In.
     NOTE(review): the relation symbols, the reduction symbol, and the
     closing tactic after try appear stripped. *)
  Lemma remove_constants_reduction n:
    1 n
    ( x, tight RE x = None ord (ctype X x) < n) OU n X OU n Y.
  Proof using consts_agree.
    intros L ?.
    assert ( I: orduni n X, ord' (map (ctype X) (iConsts I)) < n) as O.
    - intros ?; destruct n; try .
      eapply le_n_S, ord'_elements.
      intros; mapinj. cbn in ; simplify in .
      eapply filter_In in as [].
      destruct (R' x == None); try discriminate.
      now eapply le_S_n, H.
    - exists ( I remove_constants n I (O I)).
      intros I; split.
      eapply remove_constants_forward.
      eapply remove_constants_backward.
  Qed.


End RemoveConstants.