
Require Import Util EqDec LengthEq Get Coq.Classes.RelationClasses MoreList AllInRel ListUpdateAt.
Require Import SizeInduction Infra.Lattice OptionR DecSolve.
Require Import Annotation AnnotationLattice.

Set Implicit Arguments.

(** [anni A]: a small container holding no value ([anni0]), one value
    ([anni1]), two values ([anni2]) or two optional values ([anni2opt]) of
    type [A]. [setAnni] uses it to supply new top annotations for the
    children of an [ann] node. *)
(* NOTE(review): constructor order and binder names (a1, a2) are relied on
   by proofs that [destruct] an [anni] and then mention the generated names
   (e.g. [anni_R_dec]); do not rename or reorder them. *)
Inductive anni (A:Type) : Type :=
| anni0 : anni A
| anni1 (a1:A) : anni A
| anni2 (a1:A) (a2:A) : anni A
| anni2opt (a1:option A) (a2:option A) : anni A.

(* Declare the type parameter of the nullary constructor implicit. *)
Arguments anni0 [A].

(** [setAnni a ai] pushes the values carried by [ai] into the immediate
    children of [a], using [setTopAnn] on each child:
    - an [ann1] node paired with [anni1 a1] gets [a1] set on its child;
    - an [ann2] node paired with [anni2 a1 a2] gets [a1] set on its left
      child and [a2] set on its right child;
    - every other combination (including [anni0] and [anni2opt]) returns
      [a] unchanged.
    The annotation stored at the node itself is always kept as is. *)
(* NOTE(review): the body is not recursive; it is kept a [Fixpoint] so that
   existing proofs relying on its reduction behaviour are unaffected. *)
Fixpoint setAnni {A} (a:ann A) (ai:anni A) : ann A :=
  match a, ai with
    | ann1 a an, anni1 a1 => ann1 a (setTopAnn an a1)
    | ann2 a an1 an2, anni2 a1 a2 => ann2 a (setTopAnn an1 a1) (setTopAnn an2 a2)
    | an, _ => an
  end.

(** [mapAnni f ai] applies [f] to every value stored in [ai]. The
    constructor is preserved. The optional payloads of [anni2opt] are
    mapped with [mapOption f]. *)
Definition mapAnni {A B} (f:A -> B) (ai:anni A) : anni B :=
  match ai with
    | anni0 => anni0
    | anni1 a1 => anni1 (f a1)
    | anni2 a1 a2 => anni2 (f a1) (f a2)
    | anni2opt o1 o2 => anni2opt (mapOption f o1) (mapOption f o2)
  end.

(** [anni_R R] lifts a relation [R : A -> B -> Prop] to [anni A] and
    [anni B]. Two values are related exactly when they use the same
    constructor and their payloads are related componentwise: by [R] for
    [anni1] and [anni2], and by [option_R R] for the optional payloads of
    [anni2opt]. *)
Inductive anni_R A B (R : A -> B -> Prop) : anni A -> anni B -> Prop :=
| anni_R0
  : anni_R R anni0 anni0
| anni_R1 a1 a2
  : R a1 a2 -> anni_R R (anni1 a1) (anni1 a2)
| anni_R2 a1 a1' a2 a2'
  : R a1 a2 -> R a1' a2' -> anni_R R (anni2 a1 a1') (anni2 a2 a2')
| anni_R2o o1 o1' o2 o2'
  : option_R R o1 o2 -> option_R R o1' o2' -> anni_R R (anni2opt o1 o1') (anni2opt o2 o2').

(** [anni_R R] is reflexive whenever [R] is. The proof does a case
    analysis on the [anni] value. *)
Instance anni_R_refl {A} {R} `{Reflexive A R} : Reflexive (anni_R R).
  hnf; intros; destruct x; eauto using anni_R, option_R.
  econstructor; reflexivity.

(** [anni_R R] is symmetric whenever [R] is. The proof inverts the
    [anni_R] hypothesis and rebuilds it with the components swapped. *)
Instance anni_R_sym {A} {R} `{Symmetric A R} : Symmetric (anni_R R).
  hnf; intros. inv H0; eauto using anni_R.
  econstructor; symmetry; eauto.

(** [anni_R R] is transitive whenever [R] is. Both relation hypotheses
    ([B], [C]) are inverted; matching constructors are rebuilt and each
    component goal is closed by transitivity of the underlying relation. *)
Instance anni_R_trans A R `{Transitive A R} : Transitive (anni_R R).
  hnf; intros ? ? ? B C.
  inv B; inv C; econstructor; eauto.
  - etransitivity; eauto.
  - etransitivity; eauto.

(** [anni_R R] is an equivalence whenever [R] is. The reflexivity,
    symmetry and transitivity components are found by typeclass
    resolution. *)
Instance anni_R_equivalence A R `{Equivalence A R} : Equivalence (anni_R R).
  econstructor; eauto with typeclass_instances.

(** Antisymmetry lifts through [anni_R]. If [R] is antisymmetric with
    respect to the equivalence [Eq], then [anni_R R] is antisymmetric with
    respect to [anni_R Eq]. Optional payloads are handled by
    [option_R_anti]. *)
Instance anni_R_anti A R Eq `{EqA:Equivalence _ Eq} `{@Antisymmetric A Eq EqA R}
  : @Antisymmetric _ (anni_R Eq) _ (anni_R R).
  intros ? ? B C. inv B; inv C; eauto using anni_R.
  econstructor; eapply option_R_anti; eauto.

(** [anni_R R a b] is decidable ([Computable]) whenever [R] is decidable
    pointwise. The proof destructs both values; [try dec_solve] handles
    the constructor pairs it can, and the three bullets decide the
    component relations of the matching-constructor cases. *)
(* NOTE(review): the binder types here look damaged in this copy (the
   relation type and the quantifier of the Computable assumption are
   missing their connectives); restore them before compiling. *)
Instance anni_R_dec A B (R:ABProp)
         `{ a b, Computable (R a b)} (a:anni A) (b:anni B) :
  Computable (anni_R R a b).
  destruct a,b; try dec_solve.
  (* anni1 / anni1: one component *)
  - decide (R a1 a0); dec_solve.
  (* anni2 / anni2: two components *)
  - decide (R a1 a0); decide (R a2 a3); dec_solve.
  (* anni2opt / anni2opt: two optional components *)
  - decide (option_R R a1 a0); decide (option_R R a2 a3); dec_solve.

(** Partial order on [anni Dom] obtained by lifting the order ([poLe]) and
    the equivalence ([poEq]) of [Dom] through [anni_R]. Decidability of
    both comes from [anni_R_dec]. *)
(* NOTE(review): the record literal appears unterminated in this copy (no
   closing brace and period before the proof bullet); restore before
   compiling. *)
Instance PartialOrder_anni Dom `{PartialOrder Dom}
: PartialOrder (anni Dom) := {
  poLe := anni_R poLe;
  poLe_dec := @anni_R_dec _ _ poLe poLe_dec;
  poEq := anni_R poEq;
  poEq_dec := @anni_R_dec _ _ poEq poEq_dec;
  (* Remaining obligation: discharged with poLe_refl; the option payload
     uses the poLe_refl of PartialOrder_option. *)
  - intros. inv H0; eauto 20 using @anni_R, poLe_refl.
    econstructor; eapply (@poLe_refl _ (PartialOrder_option Dom)); eauto.

(** [getAnni a an] returns the single payload of [an] when [an] is
    [anni1], and the default [a] for every other constructor. *)
Definition getAnni A (a:A) (an:anni A) :=
  match an with
  | anni1 a => a
  | _ => a
  end.

(** [getAnni] is monotone. If the defaults are related by [poLe] and the
    [anni] values are related by [poLe], the results are related by
    [poLe]. *)
Lemma poLe_getAnni A `{PartialOrder A} (a a':A) an an'
  : poLe a a'
     poLe an an'
     poLe (getAnni a an) (getAnni a' an').
  intros LE LE'; inv LE'; simpl; eauto.

(** [getAnni] respects [poEq] in both arguments. *)
Lemma poEq_getAnni A `{PartialOrder A} (a a':A) an an'
  : poEq a a'
     poEq an an'
     poEq (getAnni a an) (getAnni a' an').
  intros LE LE'; inv LE'; simpl; eauto.

Hint Resolve poLe_getAnni poEq_getAnni.

(** [getAnniLeft a an] returns the first payload of [an] when [an] is
    [anni2], and the default [a] for every other constructor. *)
Definition getAnniLeft A (a:A) (an:anni A) :=
  match an with
  | anni2 a _ => a
  | _ => a
  end.

(** [getAnniLeft] is monotone. Related defaults and related [anni] values
    give [poLe]-related results. *)
Lemma poLe_getAnniLeft A `{PartialOrder A} (a a':A) an an'
  : poLe a a'
     poLe an an'
     poLe (getAnniLeft a an) (getAnniLeft a' an').
  intros LE LE'; inv LE'; simpl; eauto.

(** [getAnniLeft] respects [poEq] in both arguments. *)
Lemma poEq_getAnniLeft A `{PartialOrder A} (a a':A) an an'
  : poEq a a'
     poEq an an'
     poEq (getAnniLeft a an) (getAnniLeft a' an').
  intros LE LE'; inv LE'; simpl; eauto.

Hint Resolve poLe_getAnniLeft poEq_getAnniLeft.

(** [getAnniRight a an] returns the second payload of [an] when [an] is
    [anni2], and the default [a] for every other constructor. *)
Definition getAnniRight A (a:A) (an:anni A) :=
  match an with
  | anni2 _ a => a
  | _ => a
  end.

(** [getAnniRight] is monotone. Related defaults and related [anni] values
    give [poLe]-related results. *)
Lemma poLe_getAnniRight A `{PartialOrder A} (a a':A) an an'
  : poLe a a'
     poLe an an'
     poLe (getAnniRight a an) (getAnniRight a' an').
  intros LE LE'; inv LE'; simpl; eauto.

(** [getAnniRight] respects [poEq] in both arguments. *)
Lemma poEq_getAnniRight A `{PartialOrder A} (a a':A) an an'
  : poEq a a'
     poEq an an'
     poEq (getAnniRight a an) (getAnniRight a' an').
  intros LE LE'; inv LE'; simpl; eauto.

Hint Resolve poLe_getAnniRight poEq_getAnniRight.

(** The annotation built by [setAnn bottom s'] is below every annotation
    [d] that is well-formed for [s'] ([annotation s' d]). The proof uses
    size induction on [s'] and [bottom_least] at every node. *)
(* NOTE(review): the quantifier, implication and ordering symbol of the
   statement are missing in this copy; presumably the conclusion is
   poLe (setAnn bottom s') d -- confirm. *)
Lemma ann_bottom s' (Dom:Type) `{LowerBounded Dom}
  : (d : ann Dom), annotation s' d setAnn bottom s' d.
  sind s'; destruct s'; simpl; intros d' Ann; inv Ann; simpl;
    eauto using bottom_least.
  - econstructor; eauto using bottom_least with len.
    + intros; inv_get. eapply IH; eauto.
    + eapply IH; eauto.

(** [setTopAnnO a al] sets the top annotation of [a] (via [setTopAnn]) to
    the value in [al], or to [bottom] when [al] is [None]. *)
Definition setTopAnnO {A} `{LowerBounded A} a (al:option A) :=
  match al with
  | None => setTopAnn a bottom
  | Some al' => setTopAnn a al'
  end.

(** [setTopAnnO] preserves well-formedness. If [annotation s a] holds, then
    [annotation s (setTopAnnO a al)] holds too. Both branches reduce to
    [setTopAnn_annotation]. *)
Lemma setTopAnnO_annotation A `{LowerBounded A} a (al:option A) s
  : annotation s a annotation s (setTopAnnO a al).
  intros. unfold setTopAnnO; cases; eauto using setTopAnn_annotation.

(** [setTopAnnO] is monotone in both arguments with respect to [poLe].
    Here [a] and [b] are optional new top values. The mixed [None]/[Some]
    case is closed with [bottom_least]. *)
Lemma ann_R_setTopAnnO_poLe (A : Type) `{PartialOrder A} `{LowerBounded A} a b
         (an : ann A) (bn : ann A)
  : poLe a b poLe an bn poLe (setTopAnnO an a) (setTopAnnO bn b).
  intros. unfold setTopAnnO; repeat cases; eauto.
   eapply ann_R_setTopAnn; eauto. eapply bottom_least.

(** [setTopAnnO] respects [poEq] in both arguments. *)
Lemma ann_R_setTopAnnO_poEq (A : Type) `{PartialOrder A} `{LowerBounded A} a b
         (an : ann A) (bn : ann A)
  : poEq a b poEq an bn poEq (setTopAnnO an a) (setTopAnnO bn b).
  intros. unfold setTopAnnO; repeat cases; eapply ann_R_setTopAnn; eauto.

Hint Resolve ann_R_setTopAnnO_poLe ann_R_setTopAnnO_poEq.

(** [join] on lists is monotone in both arguments. The list order is
    unfolded ([hnf]) to its [PIR2] form, and the proof proceeds by
    induction on the first [PIR2] derivation. *)
Lemma PIR2_ojoin_zip A `{JoinSemiLattice A} (a:list A) a' b b'
  : poLe a a'
     poLe b b'
     poLe (join a b) (join a' b').
  intros. hnf in H1,H2. general induction H1; inv H2; simpl; eauto using PIR2.
  econstructor; eauto.
  rewrite pf, pf0. reflexivity. eapply IHPIR2; eauto.

(** [join] on lists respects [poEq] in both arguments. *)
Lemma poEq_join_zip A `{JoinSemiLattice A} (a:list A) a' b b'
  : poEq a a'
     poEq b b'
     poEq (join a b) (join a' b').
  intros. hnf in H1,H2. general induction H1; inv H2; simpl; eauto using PIR2.
  econstructor; eauto.
  rewrite pf, pf0. reflexivity. eapply IHPIR2; eauto.

Hint Resolve PIR2_ojoin_zip poEq_join_zip.

(** Overwriting position [n] of [tab bottom L] with [poLe]-related values
    yields [poLe]-related lists. [tab bottom L] is presumably the list of
    [bottom]s with one entry per element of [L]; confirm. The proof is by
    induction on [L]. *)
Lemma update_at_poLe A `{LowerBounded A} B (L:list B) n (a:A) b
  : poLe a b
     poLe (list_update_at (tab bottom L) n a)
            (list_update_at (tab bottom L) n b).
  general induction L; simpl; eauto using PIR2.
  - destruct n; simpl; eauto using @PIR2.

(** Same as [update_at_poLe], for [poEq]. *)
Lemma update_at_poEq A `{LowerBounded A} B (L:list B) n (a:A) b
  : poEq a b
     poEq (list_update_at (tab bottom L) n a)
            (list_update_at (tab bottom L) n b).
  general induction L; simpl; eauto using PIR2.
  - destruct n; simpl; eauto using @PIR2.

Hint Resolve update_at_poLe update_at_poEq.

(** Folding pointwise [join] ([fold_left (zip join)]) over a list of lists,
    starting from [B], is monotone in both the list of lists and the start
    value. The proof is by induction on the first hypothesis. *)
Lemma PIR2_fold_zip_join X `{JoinSemiLattice X} (A A':list (list X)) (B B':list X)
  : poLe A A'
     poLe B B'
     poLe (fold_left (zip join) A B)
           (fold_left (zip join) A' B').
  intros. simpl in ×.
  general induction H1; simpl; eauto.

(** Same as [PIR2_fold_zip_join], for [poEq]. *)
Lemma PIR2_fold_zip_join_poEq X `{JoinSemiLattice X} (A A':list (list X)) (B B':list X)
  : poEq A A'
     poEq B B'
     poEq (fold_left (zip join) A B)
           (fold_left (zip join) A' B').
  intros. simpl in ×.
  general induction H1; simpl; eauto.

Hint Resolve PIR2_fold_zip_join PIR2_fold_zip_join_poEq.

(** For [poLe]-related lists [AL] and [AL'], the lists [tab false AL] and
    [tab false AL'] are related. The proof is by induction on the [PIR2]
    form of the hypothesis, unfolding [impb] at each element. *)
Lemma tab_false_impb Dom `{PartialOrder Dom} AL AL'
  : poLe AL AL'
     poLe (tab false AL) (tab false AL').
  intros. hnf in H0.
  general induction H0; simpl; unfold impb; eauto.

(** Setting position [n] to [true] in [tab false AL] and [tab false AL']
    preserves the relation obtained from [tab_false_impb]. The proof is by
    induction on the relation between [AL] and [AL']. *)
Lemma update_at_impb Dom `{PartialOrder Dom} AL AL' n
  : poLe AL AL'
     poLe (list_update_at (tab false AL) n true)
            (list_update_at (tab false AL') n true).
  intros A. general induction A; simpl; eauto.
  - destruct n; simpl; eauto using @PIR2, tab_false_impb.

(** Fold [PIR2 poLe] / [PIR2 poEq] back into the list [poLe] / [poEq], in a
    hypothesis or in the goal. [change] checks that both forms are
    convertible. Registered with the [inversion_cleanup] smpl database. *)
(* Fix: the poEq-hypothesis branch previously changed PIR2 (@poLe D _)
   instead of PIR2 (@poEq D _) (copy-paste slip), so hypotheses of the
   form PIR2 poEq ... were never refolded. *)
Ltac refold_PIR2_PO :=
  match goal with
  | [ H : context [ PIR2 (@poLe ?D _) ] |- _ ] =>
    change (PIR2 (@poLe D _)) with (@poLe (list D) _) in H
  | [ H : context [ PIR2 (@poEq ?D _) ] |- _ ] =>
    change (PIR2 (@poEq D _)) with (@poEq (list D) _) in H
  | [ |- context [ PIR2 (@poLe ?D ?PO) ] ] =>
    change (PIR2 (@poLe D PO)) with (@poLe (list D) _)
  | [ |- context [ PIR2 (@poEq ?D ?PO) ] ] =>
    change (PIR2 (@poEq D PO)) with (@poEq (list D) _)
  end.

Smpl Add refold_PIR2_PO : inversion_cleanup.

(** Fold [ann_R poLe] / [ann_R poEq] back into the [poLe] / [poEq] on
    annotations ([ann A]), in a hypothesis or in the goal. [change] checks
    that both forms are convertible. Registered with the
    [inversion_cleanup] smpl database. *)
Ltac refold_ann_PO :=
  match goal with
  | [ H : context [ @ann_R ?A ?A (@poLe ?A ?I) ] |- _ ] =>
    change (@ann_R A A (@poLe A I)) with (@poLe (@ann A) _) in H
  | [ |- context [ ann_R poLe ?x ?y ] ] =>
    change (ann_R poLe x y) with (poLe x y)
  | [ H : context [ @ann_R ?A ?A (@poEq ?A ?I) ] |- _ ] =>
    change (@ann_R A A (@poEq A I)) with (@poEq (@ann A) _) in H
  | [ |- context [ ann_R poEq ?x ?y ] ] =>
    change (ann_R poEq x y) with (poEq x y)
  end.

Smpl Add refold_ann_PO : inversion_cleanup.

(* Make the join congruence lemmas available to eauto. *)
Hint Resolve join_struct join_struct_eq.

(** [join] on lists is monotone in both arguments. The proof is by
    induction on the first list relation. *)
(* NOTE(review): same statement as PIR2_ojoin_zip (only the variable names
   differ); one of the two could be derived from the other. *)
Lemma PIR2_poLe_join X `{JoinSemiLattice X} (A A' B B':list X)
  : poLe A A'
     poLe B B'
     poLe (join A B) (join A' B').
  intros AA BB.
  general induction AA; simpl; inv BB; eauto.

Hint Resolve PIR2_poLe_join.

(** Pointwise [orb] on boolean lists is monotone with respect to pointwise
    [impb]. This is [PIR2_poLe_join] instantiated at [bool]. *)
Lemma PIR2_impb_orb A A' B B'
  : PIR2 impb A A'
     PIR2 impb B B'
     PIR2 impb (orb A B) (orb A' B').
  intros. pose proof (@PIR2_poLe_join bool _ _).
  eapply H1; eauto.

(** Registered with the [inv_trivial] smpl database at priority 10: run
    [simpl] on hypotheses of the form [_ < _] or [_ <= _]. *)
(* Fix: the second pattern had lost its comparison symbol, leaving
   [_ _], which matches any applied hypothesis; restored to [_ <= _]. *)
Smpl Add 10 match goal with
         | [ H : _ < _ |- _ ] => simpl in H
         | [ H : _ <= _ |- _ ] => simpl in H
         end : inv_trivial.

(** Registered with the [inv_trivial] smpl database: a hypothesis
    [S _ < 0] or [S _ <= 0] is contradictory, so close the goal by
    inverting it. *)
(* Fix: the second pattern had lost its [<=] ([S _ 0] is ill-formed). *)
Smpl Add match goal with
         | [ H : S _ < 0 |- _ ] => exfalso; inv H
         | [ H : S _ <= 0 |- _ ] => exfalso; inv H
         end : inv_trivial.

(* Make join_poLe available to eauto. *)
Hint Resolve join_poLe.

(** An upper bound of [x] stays an upper bound after joining on the right:
    [x] below [y] implies [x] below [join y z]. *)
Lemma join_poLe_left X `{JoinSemiLattice X} x y z
  : poLe x y poLe x (join y z).
  intros LE. rewrite LE. eauto.

(** Mirror of [join_poLe_left]: [x] below [y] implies [x] below
    [join z y]. It is reduced to the left case via [join_commutative]. *)
Lemma join_poLe_right X `{JoinSemiLattice X} x y z
  : poLe x y poLe x (join z y).
  intros LE. rewrite LE. rewrite join_commutative. eauto.

(* Registered with hint cost 50 (a lower priority than the default). *)
Hint Resolve join_poLe_left join_poLe_right | 50.

(** If [join y z] is below [x], then so is [y]. *)
Lemma join_poLe_left_inv X `{JoinSemiLattice X} x y z
  : poLe (join y z) x poLe y x.
  intros LE. rewrite <- LE. eauto.

(** If [join z y] is below [x], then so is [y]. *)
Lemma join_poLe_right_inv X `{JoinSemiLattice X} x y z
  : poLe (join z y) x poLe y x.
  intros LE. rewrite <- LE. eauto.

(* le_S_n as a hint with cost 100. *)
Hint Resolve le_S_n | 100.

(** List version of [join_poLe_left]: if [A] is below [A'], then [A] is
    below [join A' B], given a length condition relating [A] and [B]. *)
(* NOTE(review): the relation between the two lengths is missing in this
   copy (presumably <=); confirm. *)
Lemma PIR2_poLe_join_right X `{JoinSemiLattice X} A A' B
  : length A length B
     poLe A A'
     poLe A (join A' B).
  intros LEN AA.
  general induction AA; destruct B; simpl in *; clear_trivial_eqs; eauto.

(** List version of [join_poLe_right]: if [A] is below [A'], then [A] is
    below [join B A'], under the same length condition. *)
Lemma PIR2_poLe_join_left X `{JoinSemiLattice X} A A' B
  : length A length B
     poLe A A'
     poLe A (join B A').
  intros LEN AA.
  general induction AA; destruct B; simpl in *; clear_trivial_eqs; eauto.

Hint Resolve PIR2_poLe_join_right PIR2_poLe_join_left | 50.

(** Registered with the [inversion_cleanup] smpl database at priority 50:
    replace [impb] with the [poLe] of [PartialOrder_bool], in a hypothesis
    or in the goal, so that generic [poLe] lemmas can be used on it. *)
Smpl Add 50 match goal with
            | [ H : context [ impb ] |- _ ] =>
              change impb with (@poLe bool PartialOrder_bool) in H
            | [ |- context [ impb ] ] =>
              change impb with (@poLe bool PartialOrder_bool)
            end : inversion_cleanup.

(** Registered with the [inversion_cleanup] smpl database at priority 50:
    replace [orb] with the [join] of [bool_joinsemilattice], in a
    hypothesis or in the goal, so that generic [join] lemmas can be used on
    it. *)
Smpl Add 50 match goal with
            | [ H : context [ orb ] |- _ ] =>
              change orb with (@join bool _ bool_joinsemilattice) in H
            | [ |- context [ orb ] ] =>
              change orb with (@join bool _ bool_joinsemilattice)
            end : inversion_cleanup.

(** [poLe]-related lists have related lengths. Both lemmas follow from
    [PIR2_length] after unfolding the list order. *)
(* NOTE(review): the length brackets and the ordering symbol are missing in
   this copy; presumably the conclusion is length A <= length B -- confirm. *)
Lemma poLe_length X `{PartialOrder X} A B
  : poLe A B
     A B.
  intros. hnf in H0. erewrite PIR2_length; eauto.

(** [poLe]-related lists have equal lengths. *)
Lemma poLe_length_eq X `{PartialOrder X} A B
  : poLe A B
     A = B.
  intros. hnf in H0. erewrite PIR2_length; eauto.

Hint Resolve poLe_length : len.
Hint Resolve poLe_length_eq : len.

(** [length] maps [poLe]-related lists to equal numbers, which allows
    rewriting with list [poLe] under [length]. *)
Instance poLe_length_proper X `{PartialOrder X}
  : Proper (poLe ==> eq) (@length X).
  unfold Proper, respectful; intros.
  eauto with len.

(** [length] maps [poEq]-related lists to equal numbers. *)
Instance poEq_length_proper X `{PartialOrder X}
  : Proper (poEq ==> eq) (@length X).
  unfold Proper, respectful; intros.
  eauto with len.

(** Monotonicity of a conditional fold over boolean lists. Each element
    [b] of the folded list is a pair of a boolean list and a flag; when the
    flag [snd b] is [true], [fst b] is or-ed pointwise into the
    accumulator, otherwise the accumulator is kept. The fold is monotone in
    both the list and the start value, provided every [fst a] in [A] is
    long enough compared to [B]. *)
(* NOTE(review): the quantifier and the length relation of the third
   hypothesis are missing in this copy (presumably
   forall n a, get A n a -> length B <= length (fst a)); confirm. *)
Lemma PIR2_impb_fold (A A':list (list bool × bool)) (B B':list bool)
  : poLe A A'
     poLe B B'
     ( n a, get A n a length B length (fst a))
     poLe (fold_left (fun a (b:list bool × bool) ⇒ if snd b then orb a (fst b) else a) A B)
           (fold_left (fun a (b:list bool × bool) ⇒ if snd b then orb a (fst b) else a) A' B').
  general induction H; simpl; inv_cleanup; eauto.
  eapply IHPIR2; eauto using PIR2_impb_orb.
  - exploit H1; eauto using get.
    inv pf.
    repeat cases; eauto.
    eapply PIR2_poLe_join_right; eauto using get.
    rewrite <- H3; eauto.
  - intros. cases; eauto using get.
    rewrite zip_length3; eauto using get.

(** For equally long lists, if [join A B] is below [C], then [A] is below
    [C]. *)
Lemma PIR2_zip_join_inv_left X `{JoinSemiLattice X} A B C
  : poLe (join A B) C
     length A = length B
     poLe A C.
  intros. length_equify.
  general induction H1; inv H2; simpl in *; clear_trivial_eqs;
    eauto using join_poLe_left_inv.

(** For equally long lists, if [join A B] is below [C], then [B] is below
    [C]. *)
Lemma PIR2_zip_join_inv_right X `{JoinSemiLattice X} A B C
  : poLe (join A B) C
     length A = length B
     poLe B C.
  general induction H2; inv H1; clear_trivial_eqs; eauto using join_poLe_right_inv.

(** For equally long lists, [A] is below [join A B]. *)
Lemma PIR2_poLe_zip_join_left X `{JoinSemiLattice X} A B
  : length A = length B
     poLe A (join A B).
  general induction H1; simpl in *; eauto using PIR2; try solve [ congruence ].

(** List [join] is commutative up to [poLe]: [join B A] is below
    [join A B]. *)
Lemma PIR2_zip_join_commutative X `{JoinSemiLattice X} A B
  : poLe (join B A) (join A B).
  general induction A; destruct B; simpl in *; eauto.
  eauto using join_commutative.

(** For equally long lists, [B] is below [join A B]. It is derived from
    the left version by commutativity. *)
Lemma PIR2_poLe_zip_join_right X `{JoinSemiLattice X} A B
  : length A = length B
     poLe B (join A B).
  intros. rewrite <- PIR2_zip_join_commutative.   eapply PIR2_poLe_zip_join_left; congruence.

(** If the pointwise-join fold of [A] starting at [B] is below [C], then
    the start value [B] is below [C]. Every list in [A] must have the same
    length as [B]. *)
(* NOTE(review): the quantifier and implication in the length hypotheses
   below are missing in this copy (presumably
   forall n a, get A n a -> length a = length B); confirm. *)
Lemma PIR2_fold_zip_join_inv X `{JoinSemiLattice X} A B C
  : poLe (fold_left (zip join) A B) C
     ( n a, get A n a length a = length B)
     poLe B C.
  general induction A; simpl in *; eauto.
  eapply IHA; eauto using get.
  rewrite <- H1. eauto.
  eapply PIR2_fold_zip_join; eauto.
  eapply PIR2_poLe_zip_join_left.
  symmetry. eauto using get.

(** If [A] is below the start value [C], it is below the pointwise-join
    fold of [B] starting at [C]. Every list in [B] must have the same
    length as [C]. *)
Lemma PIR2_fold_zip_join_right X `{JoinSemiLattice X} (A:list X) B C
  : ( n a, get B n a length a = length C)
     poLe A C
     poLe A (fold_left (zip join) B C).
  general induction B; simpl; eauto.
  eapply IHB; intros; eauto using get with len.
  - rewrite zip_length2; eauto using eq_sym, get.
  - rewrite H2. eapply PIR2_poLe_zip_join_left. symmetry. eauto using get.

(** If [A] is below some element [a] (at index [k]) of [B], it is below the
    pointwise-join fold of [B] starting at [C]. Every list in [B] must have
    the same length as [C]. *)
Lemma PIR2_fold_zip_join_left X `{JoinSemiLattice X} (A:list X) B C a k
  : get B k a
     poLe A a
     ( n a, get B n a length a = length C)
     poLe A (fold_left (zip join) B C).
  general induction B; simpl in *; eauto.
  - inv H1.
    + eapply PIR2_fold_zip_join_right.
      intros. rewrite zip_length2; eauto using eq_sym, get.
      rewrite H2. eapply PIR2_poLe_zip_join_right.
      eauto using eq_sym, get.
    + eapply IHB; eauto using get.
      intros. rewrite zip_length2; eauto using eq_sym, get.

(** If [b] has [x] at index [n], then the pointwise-join fold of [A]
    starting at [b] has, at index [n], some [y] that is above [x]. Every
    list in [A] must have the same length as [b]. *)
(* NOTE(review): the quantifiers, implications, existential and the length
   brackets of the statements below are missing in this copy; presumably
   forall n a, get A n a -> length a = length b, and
   exists y, get (...) n y /\ poLe x y -- confirm. *)
Lemma get_union_union_b X `{JoinSemiLattice X} (A:list (list X)) (b:list X) n x
  : get b n x
     ( n a, get A n a a = b)
     y, get (fold_left (zip join) A b) n y poLe x y.
  intros GETb LEN. general induction A; simpl in ×.
  - eexists x. eauto with cset.
  - exploit LEN; eauto using get.
    edestruct (get_length_eq _ GETb (eq_sym H1)) as [y GETa]; eauto.
    exploit (zip_get join GETb GETa).
    + exploit IHA; try eapply GET; eauto.
      rewrite zip_length2; eauto using get with len.
      edestruct H3; dcr; subst. eexists; split; eauto using join_poLe_left_inv.

(** If every list in [A] has the same length as [b] and [n] is a valid
    index of [b], then the fold of [zip f] over [A] starting at [b] has an
    element at index [n]. *)
Lemma get_fold_zip_join X (f: X X X) (A:list (list X)) (b:list X) n
  : ( n a, get A n a a = b)
     n < b
     y, get (fold_left (zip f) A b) n y.
  intros LEN. general induction A; simpl in ×.
  - edestruct get_in_range; eauto.
  - exploit LEN; eauto using get.
    eapply IHA; eauto using get with len.

(** If some list [a] of [A] (at index [k]) has [x] at index [n], then the
    pointwise-join fold of [A] starting at [b] has, at index [n], some [y]
    that is above [x]. Every list in [A] must have the same length as
    [b]. *)
Lemma get_union_union_A X `{JoinSemiLattice X} (A:list (list X)) a b n k x
  : get A k a
     get a n x
     ( n a, get A n a a = b)
     y, get (fold_left (zip join) A b) n y poLe x y.
  intros GETA GETa LEN.
  general induction A; simpl in × |- *; isabsurd.
  inv GETA; eauto.
  - exploit LEN; eauto using get.
    edestruct (get_length_eq _ GETa H1) as [y GETb].
    exploit (zip_get join GETb GETa).
    exploit (@get_union_union_b _ _ _ A); eauto.
    rewrite zip_length2; eauto using get with len.
    destruct H3; dcr; subst.
    eexists; split; eauto using join_poLe_right_inv.
  - exploit IHA; eauto.
    rewrite zip_length2; eauto using get.
    symmetry; eauto using get.

(** Inversion for the pointwise-[orb] fold. If the fold of [A] starting at
    [b] is [true] at index [n], then either [b] is [true] at [n], or some
    list [a] of [A] is [true] at [n]. *)
(* NOTE(review): the implication, disjunction and existential of the
   statement are missing in this copy; confirm against the original. *)
Lemma fold_left_zip_orb_inv A b n
  : get (fold_left (zip orb) A b) n true
     get b n true k a, get A k a get a n true.
  intros Get.
  general induction A; simpl in *; eauto.
  edestruct IHA; dcr; eauto 20 using get.
  inv_get. destruct x, x0; isabsurd; eauto using get.

(** The pointwise-[orb] fold is monotone in both the list of lists and the
    start value. *)
Lemma fold_left_mono A A' b b'
  : poLe A A'
     poLe b b'
     poLe (fold_left (zip orb) A b) (fold_left (zip orb) A' b').
  hnf in H. general induction H; simpl; eauto. inv_cleanup.
  - eapply IHPIR2; eauto.

(** [fold_left f a b] has the same length as [b], provided [f] preserves
    the length of its accumulator whenever the accumulator is suitably
    sized compared with the first component of the element, and every
    element of [a] satisfies that size condition against [b]. *)
(* NOTE(review): the arrows, quantifiers, length brackets and the length
   relation in the hypotheses of both lemmas below are missing in this
   copy (presumably length b <= length (fst aa) -> length (f b aa) =
   length b); confirm. *)
Lemma fold_list_length A B (f:list B (list A × bool) list B) (a:list (list A × bool)) (b: list B)
  : ( n aa, get a n aa b fst aa)
     ( aa b, b fst aa f b aa = b)
     length (fold_left f a b) = b.
  intros LEN.
  general induction a; simpl; eauto.
  erewrite IHa; eauto 10 using get with len.
  intros. rewrite H; eauto using get.

(** Same as [fold_list_length], for a fold over plain lists, where the size
    condition refers to the element itself rather than its first
    component. *)
Lemma fold_list_length' A B (f:list B (list A) list B) (a:list (list A)) (b: list B)
  : ( n aa, get a n aa b aa)
     ( aa b, b aa f b aa = b)
     length (fold_left f a b) = b.
  intros LEN.
  general induction a; simpl; eauto.
  erewrite IHa; eauto 10 using get with len.
  intros. rewrite H; eauto using get.