(** * Lvc.Lowering.EAE *)

Require Import Util LengthEq AllInRel Map SetOperations.

Require Import Val EqDec Computable Var Env IL Annotation AppExpFree.
Require Import Liveness.
Require Import SimF Fresh Filter.

Set Implicit Arguments.
Unset Printing Records.

(** ** Sequential map updates *)
Section MapUpdate.
  Open Scope fmap_scope.
  Variable X : Type.
  Context `{OrderedType X}.
  Variable Y : Type.

  (** [update_with_list' XL YL m] applies [m[x <- y]] for each pair
      [(x, y)] of [XL] and [YL], front to back, threading the updated map
      through the recursion. It stops as soon as either list is exhausted,
      so extra elements of the longer list are ignored. *)
  Fixpoint update_with_list' XL YL (m:X -> Y) :=
    match XL, YL with
      | x::XL, y::YL => update_with_list' XL YL (m[x <- y])
      | _, _ => m
    end.

  (** Updating [x] before or after a list update gives maps that agree on any
      set [D], provided [XL] and [VL] have equal length and [x::XL] is
      [unique]. *)
  Lemma update_unique_commute (XL:list X) (VL:list Y) E D x y
  : length XL = length VL
    -> unique (x::XL)
    -> agree_on eq D (E [x <- y] [XL <-- VL]) (E [XL <-- VL] [x <- y]).
  Proof.
    intros LEN UNIQ. length_equify.
    general induction LEN; simpl in * |- *; dcr; simpl in *; eauto.
    hnf; intros. lud.
    - exfalso; eauto.
    - etransitivity; [eapply IHLEN|]; eauto; lud.
    - etransitivity; [eapply IHLEN|]; eauto; lud.
  Qed.

  (** The list update [E [XL <-- VL]] and the accumulator-style
      [update_with_list' XL VL E] agree on any set [D] when [XL] is [unique]
      and the lists have equal length. *)
  Lemma update_with_list_agree' (XL:list X) (VL:list Y) E D
  : length XL = length VL
    -> unique XL
    -> agree_on eq D (E [XL <-- VL]) (update_with_list' XL VL E).
  Proof.
    intros LEN UNIQ. length_equify.
    general induction LEN; simpl in *; eauto.
    - etransitivity. symmetry. eapply update_unique_commute; eauto using length_eq_length.
      eapply IHLEN; eauto.
  Qed.
End MapUpdate.

(** [list_to_stmt xl Y s] prefixes [s] with one [stmtLet x (Operation e)]
    per pair [(x, e)] of [xl] and [Y], in order. It stops at the shorter of
    the two lists, so extra elements of the longer one are ignored. *)
Fixpoint list_to_stmt (xl: list var) (Y : list op) (s : stmt) : stmt :=
  match xl, Y with
  | x::xl, e :: Y => stmtLet x (Operation e) (list_to_stmt xl Y s)
  | _, _ => s
  end.

(** Running the let-chain built by [list_to_stmt] when every operation
    succeeds. The hypotheses are:
    - every operation in [Y] evaluates in [E], yielding [vl];
    - the binders [xl] are [unique] and disjoint from the free variables
      of [Y];
    - [xl] and [Y] have equal length.

    Then [list_to_stmt xl Y s] steps, with an empty event list ([nil]), to [s]
    in the environment where each [x] in [xl] is bound to [Some] of its
    computed value. *)
Lemma list_to_stmt_correct L E s xl Y vl
: length xl = length Y
  -> omap (op_eval E) Y = Some vl
  -> unique xl
  -> disj (of_list xl) (list_union (List.map Op.freeVars Y))
  -> star2 F.step (L, E, list_to_stmt xl Y s) nil
          (L, update_with_list' xl (List.map Some vl) E, s).
Proof.
  intros Len Eq Uni Disj.
  length_equify.
  general induction Len; simpl in * |- *; eauto using star2_refl.
  - simpl in *. monad_inv Eq.
    eapply star2_silent.
    econstructor; eauto.
    rewrite list_union_start_swap in Disj.
    eapply IHLen; eauto.
    eapply omap_op_eval_agree; eauto.
    symmetry. eapply agree_on_update_dead; [|reflexivity].
    eauto with cset.
    eapply disj_1_incl; [ eapply disj_2_incl |]; eauto with cset.
Qed.

(** Running the let-chain built by [list_to_stmt] when some operation in
    [Y] fails to evaluate ([omap] returns [None]). Execution then reaches,
    with an empty event list, a state [σ] that has no result
    ([state_result σ = None]) and cannot step further ([normal2]). *)
Lemma list_to_stmt_crash L E xl Y s
: length xl = length Y
  -> omap (op_eval E) Y = None
  -> unique xl
  -> disj (of_list xl) (list_union (List.map Op.freeVars Y))
  -> exists σ, star2 F.step (L, E, list_to_stmt xl Y s) nil σ
          /\ state_result σ = None
          /\ normal2 F.step σ.
Proof.
  intros. eapply length_length_eq in H.
  general induction H; simpl in * |- *.
  - monad_inv H0.
    + eexists. split. eapply star2_refl.
      split; eauto. stuck2.
    + rewrite list_union_start_swap in H2.
      edestruct (IHlength_eq L (E [x <- Some x0])); eauto.
      * eapply omap_op_eval_agree; eauto. symmetry.
        eapply agree_on_update_dead; [|reflexivity].
        intro. eapply (H2 x); cset_tac.
      * eapply disj_1_incl; [ eapply disj_2_incl |]; eauto with cset.
      * dcr. eexists. split; eauto.
        eapply star2_silent.
        econstructor; eauto. eauto.
Qed.

(** [replace_if p L L'] walks [L]. Each element satisfying [p] is replaced
    by the next unused element of [L']; every other element is kept as is.
    If [L'] is exhausted when a [p]-element is reached, the result is
    truncated there (the remainder is [nil]). *)
Fixpoint replace_if X (p:X -> bool) (L:list X) (L':list X) :=
  match L with
  | x::L => if p x then
              match L' with
              | y::L' => y::replace_if p L L'
              | _ => nil
              end
            else
              x::replace_if p L L'
  | _ => nil
  end.

(** Boolean tests on operations, built with the [B[_]] notation around
    [isVar e] and its negation (presumably [B[P]] decides [P] to a [bool];
    confirm against Computable). *)
Local Notation "'IsVar'" := (fun e => B[isVar e]).
Local Notation "'NotVar'" := (fun e => B[~ isVar e]).

(** [compile s] rewrites every application in [s] so that it is applied only
    to variables.

    At [stmtApp l Y]:
    - the non-variable arguments [Y'] (those satisfying [NotVar]) are
      let-bound, via [list_to_stmt], to a list [xl] of [length Y'] variables;
    - [xl] is produced by [fresh_list] relative to the free variables of [Y];
    - each non-variable argument is then replaced in [Y], in order, by the
      corresponding [Var x].

    All other constructs are traversed structurally. The expression of a
    [stmtLet] is left unchanged. *)
Fixpoint compile s {struct s}
  : stmt :=
  match s with
    | stmtLet x e s => stmtLet x e (compile s)
    | stmtIf x s t => stmtIf x (compile s) (compile t)
    | stmtApp l Y =>
      let Y' := List.filter NotVar Y in
      let xl := fresh_list fresh (list_union (List.map Op.freeVars Y)) (length Y') in
      list_to_stmt xl Y' (stmtApp l (replace_if NotVar Y (List.map Var xl)))
    | stmtReturn x => stmtReturn x
    | stmtFun F t => stmtFun (List.map (fun Zs => (fst Zs, compile (snd Zs))) F) (compile t)
  end.

(** Evaluating the variables [xl] in an environment where [xl] has just been
    bound to [Some] of [vl] returns exactly [vl]. This requires [xl] to be
    [unique] and of the same length as [vl]. *)
Lemma omap_lookup_vars V xl vl
  : length xl = length vl
    -> unique xl
    -> omap (op_eval (V[xl <-- List.map Some vl])) (List.map Var xl) = Some vl.
Proof.
  intros. eapply length_length_eq in H.
  general induction H; simpl; eauto.
  lud; isabsurd; simpl.
  erewrite omap_op_eval_agree; try eapply IHlength_eq; eauto; simpl in *; intuition.
  instantiate (1:= V [x <- Some y]).
  eapply update_unique_commute; eauto; simpl; intuition.
Qed.

(** [merge_cond K L L'] interleaves [L] and [L'] under the control of [K].
    A [true] in [K] takes the next element of [L]; a [false] takes the next
    element of [L']. The result ends ([nil]) as soon as [K] is exhausted or
    the list selected by the current flag is empty. *)
Fixpoint merge_cond (Y:Type) (K:list bool) (L:list Y) (L':list Y) :=
  match K, L, L' with
  | true::K, x::L, L' => x::merge_cond K L L'
  | false::K, L, y::L' => y::merge_cond K L L'
  | _, _, _ => nil
  end.

(** If [omap f] fails on a filtered sublist of [L], it also fails on [L]. *)
Lemma omap_filter_none X Y (f:X->option Y) (p:X->bool) (L:list X)
  : omap f (List.filter p L) = None
    -> omap f L = None.
Proof.
  general induction L; intros; simpl in *.
  cases in H; simpl in *; try monad_inv H.
  - rewrite H0; simpl; eauto.
  - rewrite EQ; simpl. erewrite H, IHL; eauto.
  - destruct (f a); simpl; eauto.
    erewrite IHL; eauto.
Qed.

(** Suppose [p] and [q] are complementary on the elements of [L]
    ([negb (p x) = q x]) and [omap f] succeeds on both filtered parts,
    yielding [vl1] and [vl2]. Then [omap f L] succeeds, and its result is the
    merge of [vl1] and [vl2] selected by [List.map p L]. *)
Lemma omap_filter_partitions X Y (f:X->option Y) (p q:X->bool) (L:list X) vl1 vl2
  : omap f (List.filter p L) = Some vl1
    -> omap f (List.filter q L) = Some vl2
    -> (forall n x, get L n x -> negb (p x) = q x)
    -> omap f L = Some (merge_cond (List.map p L) vl1 vl2).
Proof.
  general induction L; intros; simpl in *; eauto.
  cases in H; simpl in *; try monad_inv H.
  - erewrite <- H1 in H0; eauto using get.
    rewrite <- Heq in H0. simpl in *.
    rewrite EQ. simpl.
    rewrite (IHL _ _ _ _ _ _ EQ1 H0); eauto using get.
  - erewrite <- H1 in H0; eauto using get.
    rewrite <- Heq in H0. simpl in *.
    monad_inv H0. rewrite EQ. simpl.
    rewrite (IHL _ _ _ _ _ _ H EQ1); eauto using get.
Qed.

(** Evaluating [replace_if NotVar Y Y'] succeeds under two conditions:
    the variable arguments of [Y] evaluate to [vl1], and the replacement
    list [Y'] evaluates to [vl0]. The result is then the [IsVar]-controlled
    merge of [vl1] and [vl0]. *)
Lemma omap_replace_if V Y Y' vl0 vl1
  : omap (op_eval V) (List.filter IsVar Y) = Some vl1
    -> omap (op_eval V) Y' = Some vl0
    -> omap
        (op_eval V)
        (replace_if NotVar Y Y') = Some (merge_cond (List.map IsVar Y) vl1 vl0).
Proof.
  general induction Y; simpl; eauto.
  simpl in *. cases in H; cases; isabsurd; simpl in *.
  - monad_inv H. rewrite EQ. simpl. erewrite IHY; eauto. eauto.
  - destruct Y'; simpl in *.
    + inv H0; eauto.
    + monad_inv H0. rewrite EQ. simpl.
      erewrite IHY; eauto. simpl; eauto.
Qed.

(** Pointwise proof relation used by [sim_EAE']. Both [ParamRelFP] and
    [ArgRelFP] require the two lists to be equal and to have the same length
    as the context entry [G]. *)
Instance SR : PointwiseProofRelationF params := {
   ParamRelFP G VL VL' := VL = VL' /\ length VL = length G;
   ArgRelFP G Z Z' := Z = Z' /\ length Z = length G
}.

(** Generalized simulation lemma for [compile]. Given label environments
    [L = L'] that are related by [labenv_sim] under [SR], the source [s] and
    [compile s], started in the same environment [V], are related by
    [sim r Sim]. *)
Lemma sim_EAE' r L L' V s
  : labenv_sim Sim (sim r) SR (List.map block_Z L) L L'
    -> L = L'
    -> sim r Sim (L, V, s) (L',V, compile s).
Proof.
  revert_except s.
  sind s; destruct s; simpl; intros; simpl in * |- *.
  - (* stmtLet: compile only recurses into the body *)
    destruct e.
    + eapply (sim_let_op il_statetype_F); eauto.
    + eapply (sim_let_call il_statetype_F); eauto.
  - (* stmtIf *)
    eapply (sim_cond il_statetype_F); eauto.
  - (* stmtApp: run the generated let-chain, then relate the two calls *)
    case_eq (omap (op_eval V) (List.filter NotVar Y)); intros.
    + destruct (get_dec L (counted l)) as [[[bE bZ bs n]]|].
      * decide (length Y = length bZ).
        -- eapply sim_expansion_closed;
             [
             | eapply star2_refl
             | eapply list_to_stmt_correct;
               eauto using fresh_spec, fresh_list_unique, fresh_list_spec
             ]; eauto.
           ++ eapply labenv_sim_app; eauto. simpl.
             intros; split; intros; eauto; dcr; subst.
             case_eq (omap (op_eval V) (List.filter IsVar Y)); intros.
             ** exploit (omap_filter_partitions _ _ _ H4 H1).
                intros; repeat cases; eauto.
                exists Yv; repeat split; eauto with len.
                erewrite omap_replace_if.
                --- rewrite <- H7; eauto.
                --- erewrite omap_op_eval_agree; eauto.
                    rewrite <- update_with_list_agree';
                      eauto using fresh_spec, fresh_list_unique,
                      fresh_list_spec with len.
                    eapply agree_on_incl.
                    symmetry.
                    eapply update_with_list_agree_minus; eauto.
                    eapply not_incl_minus. reflexivity.
                    symmetry.
                    eapply disj_2_incl.
                    eapply fresh_list_spec; eauto using fresh_spec.
                    eapply list_union_incl; intros; eauto with cset.
                    inv_get. repeat inv_get_step_filter idtac.
                    eapply incl_list_union; eauto using map_get_1.
                --- erewrite omap_op_eval_agree; [ eapply H1 | | ].
                    Focus 2.
                    rewrite omap_lookup_vars; eauto using fresh_list_unique, fresh_spec with len.
                    rewrite <- update_with_list_agree';
                      eauto using fresh_spec, fresh_list_unique,
                      fresh_list_spec with len. reflexivity.
             ** exfalso. eapply omap_filter_none in H4. congruence.
           ++ eapply disj_2_incl.
             eapply fresh_list_spec; eauto using fresh_spec.
             eapply list_union_incl; intros; eauto with cset.
             inv_get. eapply incl_list_union; eauto using map_get_1.
        -- perr.
      * perr.
    + perr.
      erewrite omap_filter_none in def; eauto. congruence.
  - (* stmtReturn *)
    pno_step.
  - (* stmtFun: extend the label environments with the compiled bodies *)
    pone_step.
    left. eapply IH; eauto 20 with len.
    + rewrite List.map_app.
      eapply labenv_sim_extension_ptw; eauto with len.
      * intros; hnf; intros; inv_get; simpl in *; dcr; subst.
        get_functional. eapply IH; eauto 20 with len.
        rewrite List.map_app. eauto.
      * hnf; intros; simpl in *; subst; inv_get; simpl; eauto.
Qed.

(** Top-level statement for [compile]. With an empty label environment,
    [s] and [compile s], started in the same environment [V], are related by
    [sim bot3 Sim]. The proof instantiates [sim_EAE'] and closes its
    label-environment premise with [labenv_sim_nil]. *)
Lemma sim_EAE V s
  : @sim _ statetype_F _ statetype_F bot3 Sim (nil, V, s) (nil,V, compile s).
Proof.
  eapply sim_EAE'; eauto.
  eapply labenv_sim_nil.
Qed.

(** If every argument in [Y'] satisfies [isVar], then the let-chain
    [list_to_stmt ZL Y] followed by [stmtApp l Y'] is [app_expfree]. *)
Lemma list_to_stmt_app_expfree ZL Y Y' l
  : (forall n e, get Y' n e -> isVar e)
    -> app_expfree (list_to_stmt ZL Y (stmtApp l Y')).
Proof.
  general induction Y; destruct ZL; destruct Y'; simpl;
    econstructor; intros; inv_get; isabsurd; eauto using isVar.
Qed.

(** Inversion for [replace_if]. Every element [x] at position [n] of
    [replace_if p L L'] comes from an element [l] at position [n] of [L].
    Either [p l] holds and [x] is some element of [L'], or [p l] does not
    hold and [x = l]. *)
Lemma replace_if_get_inv X (p:X -> bool) L L' n x
  : get (replace_if p L L') n x
    -> exists l , get L n l
            /\ ((p l /\ exists l' n', get L' n' l' /\ x = l')
               \/ (~ p l /\ x = l)).
Proof.
  intros; general induction L; destruct L'; simpl in *; isabsurd.
  - cases in H; isabsurd. inv H.
    + exists x; split; eauto using get.
      right; split; eauto. rewrite <- Heq. eauto.
    + edestruct IHL; eauto; dcr.
      eexists; split; eauto using get.
  - cases in H; isabsurd.
    + inv H.
      * eexists; split; eauto using get.
        left. rewrite <- Heq; eauto using get.
      * edestruct IHL; eauto using get; dcr.
        eexists; split; eauto using get.
        destruct H2; dcr; isabsurd. left; eauto 20 using get.
        right; eauto using get.
    + inv H.
      * eexists; split; eauto using get.
        right. rewrite <- Heq; eauto using get.
      * edestruct IHL; eauto; dcr.
        eexists; split; eauto using get.
Qed.

(** The output of [compile] is always [app_expfree]. The [stmtApp] case
    reduces, via [list_to_stmt_app_expfree] and [replace_if_get_inv], to
    showing that every argument of the rewritten application satisfies
    [isVar]. *)
Lemma EAE_app_expfree s
  : app_expfree (compile s).
Proof.
  sind s; destruct s; simpl; eauto using app_expfree.
  - (* stmtApp: each argument is either an original variable or a [Var] of a fresh name *)
    eapply list_to_stmt_app_expfree.
    intros.
    eapply replace_if_get_inv in H; dcr.
    destruct H2; dcr; cases in H0; isabsurd; inv_get; eauto using isVar.
    exfalso; eapply H0; eauto.
    destruct x; eauto using isVar; exfalso; eapply NOTCOND; intro; isabsurd.
  - (* stmtFun: every compiled body is app_expfree by the induction hypothesis *)
    econstructor; intros; inv_get; eauto using app_expfree.
Qed.