Require Import ssreflect Setoid Lia List.
Require Import PostTheorem.external.Shared.embed_nat.
From PostTheorem Require Import external.mu_nat.
From Coq.Logic Require Import ConstructiveEpsilon.
Import EmbedNatNotations ListNotations.

Local Set Implicit Arguments.
Local Unset Strict Implicit.

(** [monotonic f]: once the step-indexed function [f] yields [Some x] at an
    index [n1], it yields the same [Some x] at every index [n2 >= n1]. *)
Definition monotonic {X} (f : nat -> option X) :=
  forall n1 x, f n1 = Some x -> forall n2, n2 >= n1 -> f n2 = Some x.

(** [agnostic f]: all [Some] results of [f] coincide, independently of the
    index at which they are produced. *)
Definition agnostic {X} (f : nat -> option X) :=
  forall n1 n2 y1 y2, f n1 = Some y1 -> f n2 = Some y2 -> y1 = y2.

(** Every monotonic step-indexed function is agnostic. Compare the two
    indices, then push the result at the smaller index forward to the
    larger one using monotonicity. *)
Lemma monotonic_agnostic {X} (f : nat -> option X) :
  monotonic f -> agnostic f.
Proof.
  intros Hmono n1 n2 y1 y2 Hy1 Hy2.
  destruct (Compare_dec.le_ge_dec n1 n2) as [Hle | Hge].
  - (* n1 <= n2: monotonicity gives f n2 = Some y1 *)
    pose proof (Hmono n1 y1 Hy1 n2 Hle) as Hn2.
    congruence.
  - (* n2 <= n1: monotonicity gives f n1 = Some y2 *)
    pose proof (Hmono n2 y2 Hy2 n1 Hge) as Hn1.
    congruence.
Qed.

(** Abstract interface for partial values: a type constructor [part] with a
    deterministic value relation [hasvalue], return/bind operations, a
    value [undef] without any value, unbounded minimisation [mu], and a
    monotone step-indexed evaluator [seval] that characterises [hasvalue]. *)
Class partiality :=
  {

    (* Partial values of type [A]. *)
    part : Type -> Type ;

    (* [hasvalue x a]: [x] is defined with value [a]. A partial value has at
       most one value. *)
    hasvalue : forall A, part A -> A -> Prop ;
    hasvalue_det : forall A x (a1 a2 : A), hasvalue x a1 -> hasvalue x a2 -> a1 = a2 ;

    (* [ret a] has value [a]. *)
    ret : forall A, A -> part A ;
    ret_hasvalue : forall A (a : A), hasvalue (ret a) a ;

    (* [bind x f] has value [b] iff [x] has some value [a] and [f a] has
       value [b]. *)
    bind : forall A B, part A -> (A -> part B) -> part B ;
    bind_hasvalue : forall A B x(f : A -> part B) b, hasvalue (bind x f) b <-> exists a, hasvalue x a /\ hasvalue (f a) b;

    (* [undef] has no value. *)
    undef : forall {A}, part A ;
    undef_hasvalue : forall A (a : A), ~ hasvalue undef a ;

    (* [mu f] has value [n] iff [f n] has value [true] and every [f m] with
       [m < n] has value [false]. *)
    mu : (nat -> part bool) -> part nat ;
    mu_hasvalue : forall (f : nat -> part bool) n, hasvalue (mu f) n <-> (hasvalue (f n) true /\ forall m, m < n -> hasvalue (f m) false);
  
    (* [seval x n]: an [n]-indexed approximation of [x] (presumably [n]
       evaluation steps). It is monotonic in [n], and [x] has value [a]
       exactly when some index yields [Some a]. *)
    seval : forall A, part A -> nat -> option A ;
    seval_mono : forall A x, monotonic (@seval A x) ;
    seval_hasvalue : forall A x (a : A), hasvalue x a <-> exists n, seval x n = Some a ;

  }.

(** [ter x]: the partial value [x] terminates, i.e. it has some value. *)
Definition ter {Part : partiality} A (x : part A) := exists a, hasvalue x a.
(** [equiv x y]: [x] and [y] have exactly the same values. *)
Definition equiv {Part : partiality} A (x y : part A) := forall a, hasvalue x a <-> hasvalue y a.

(** [x =! a] is notation for [hasvalue x a]. *)
Notation "a =! b" := (hasvalue a b) (at level 50).

(** [A ↛ B] is the type of partial functions from [A] to [B]. *)
Notation "A ↛ B" := (A -> part B) (at level 10).

(** [pcomputes f R]: the partial function [f] computes the relation [R],
    i.e. [f x] has value [y] exactly when [R x y] holds.
    [f] must be typed with the partial-function notation [X ↛ Y]. The
    previous binder [X Y] applied [X] to [Y], which is ill-typed given
    [R : X -> Y -> Prop]. *)
Definition pcomputes {X Y} `{partiality} (f : X ↛ Y) (R : X -> Y -> Prop) :=
  forall x y, f x =! y <-> R x y.

(** [functional R]: every input is related by [R] to at most one output. *)
Definition functional {X Y} (R : X -> Y -> Prop) :=
  forall x y1 y2, R x y1 -> R x y2 -> y1 = y2.

(** [total R]: every input is related by [R] to at least one output. *)
Definition total {X Y} (R : X -> Y -> Prop) :=
  forall x, exists y, R x y.

(** A relation bundled with a proof that it is functional. [the_rel] is
    declared as a coercion, so a [FunRel X Y] can be used as a relation. *)
Record FunRel X Y := {the_rel :> X -> Y -> Prop ; the_func_proof : functional the_rel}.
(* Make the type parameters of [the_rel] implicit. *)
Arguments the_rel {_ _}.

(** [equiv] (having the same values) is an equivalence relation on
    [part A]. Registered as a global instance so it can be used with the
    setoid machinery. *)
Global Instance part_equiv_Equivalence `{partiality} {A} :
  Equivalence (@equiv _ A).
Proof.
  firstorder.
Qed.

(** Derived operations and lemmas that hold for any implementation [impl]
    of the [partiality] interface. *)
Section assume_part.

  Context {impl : partiality}.

  Lemma ret_hasvalue_inv {A} (a1 a2 : A) :
    ret a1 =! a2 -> a1 = a2.
  Proof.
    move => H.
    eapply hasvalue_det. eapply ret_hasvalue. eauto.
  Qed.

  Lemma ret_hasvalue' {A} (a1 a2 : A) :
    a1 = a2 -> ret a1 =! a2.
  Proof.
    intros ->. eapply ret_hasvalue.
  Qed.

  Lemma ret_hasvalue_iff {A} (a1 a2 : A) :
    a1 = a2 <-> ret a1 =! a2.
  Proof.
    split.
    - apply ret_hasvalue'.
    - apply ret_hasvalue_inv.
  Qed.

  (** [mu_tot f]: minimisation over a total boolean predicate. Each [f n] is
      lifted into [part bool] with [ret] and passed to [mu]. *)
  Definition mu_tot (f : nat -> bool) := mu (fun n => ret (f n)).

  (** [mu_tot f] has value [n] iff [n] is the least index with [f n = true]. *)
  Lemma mu_tot_hasvalue (f : nat -> bool) n :
    hasvalue (mu_tot f) n <-> (f n = true /\ forall m, m < n -> f m = false).
  Proof.
    (* Reduce to [mu_hasvalue], then relate the [ret _ =! _] statements and
       the plain equations on both sides via [ret_hasvalue_iff]. *)
    unfold mu_tot. rewrite mu_hasvalue. now repeat setoid_rewrite ret_hasvalue_iff.
  Qed.

  (** If [f] is true at some index [n], then [mu_tot f] terminates. The
      value is the least index where [f] is true, obtained from
      [mu_nat_dep]. *)
  Lemma mu_tot_ter (f : nat -> bool) n :
    f n = true ->
    ter (mu_tot f).
  Proof.
    move => H.
    assert (He : exists n, f n = true) by eauto.
    (* Truth of [f n] is decidable, as required by [mu_nat_dep]. *)
    assert (d : forall n, {f n = true} + {~ f n = true}) by (move => n'; destruct (f n'); firstorder congruence).
    (* [n'] is the index found by the search; [E] records it for [mu_nat_dep_min]. *)
    destruct (mu_nat_dep _ d He) as [n' Hn'] eqn:E.
    eapply (f_equal (@proj1_sig _ _)) in E.
    exists n'. eapply mu_tot_hasvalue. split.
    - eauto.
    (* Minimality of [n'] forces [f m = false] for every [m < n']. *)
    - move => m Hlt. cbn in E. subst.
      eapply mu_nat_dep_min in Hlt. destruct (f m); congruence.
  Qed.

  (** [undef' a0]: a partial value without any value, defined by binding
      [mu_tot] on the constantly-false predicate. The argument [a0] only
      gives the [ret] continuation something to return; that continuation
      is never reached (see [undef'_hasvalue]). *)
  Definition undef' : forall {A}, A -> part A := fun A a0 => bind (mu_tot (fun _ => false)) (fun n => ret a0).

  (** [undef' a0] has no value. A value would need the [mu_tot] search to
      return an index where the constantly-false predicate is true,
      i.e. [false = true]. *)
  Lemma undef'_hasvalue : forall A a0 (a : A), ~ hasvalue (undef' a0) a.
  Proof.
    intros A a0 a [a' [[[=]] % mu_tot_hasvalue H2]] % bind_hasvalue.
  Qed.

  (** [eval' H]: given a termination proof [H : ter x], returns a value of
      [x] together with a proof that [x] has it. Termination yields an index
      where [seval x] is defined, and such an index is found with
      [constructive_indefinite_ground_description_nat].
      NOTE(review): this definition is closed with [Qed], so it is opaque
      and [eval] below does not reduce by computation. If reduction is
      intended, [Defined] would be needed. Confirm before changing, since it
      can affect downstream proofs. *)
  Definition eval' {A} (x : part A) (H : ter x) : {a : A | hasvalue x a}.
  Proof.
    (* From [ter x], some index makes [seval x] defined. *)
    assert (Hx : exists n, seval x n <> None). {
      destruct H as [a [n H] % seval_hasvalue]. exists n. congruence.
    }
    (* Search for such an index. The second goal is decidability of the predicate. *)
    eapply constructive_indefinite_ground_description_nat in Hx as [n Hx].
    - destruct seval eqn:E; try congruence. exists a. eapply seval_hasvalue. firstorder.
    - move => n. destruct seval; firstorder congruence.
  Qed.

  (** [eval H]: the value of a terminating partial value [x], extracted from [eval']. *)
  Definition eval {A} (x : part A) (H : ter x) : A := proj1_sig (eval' H).
  (** [eval H] is indeed a value of [x]. *)
  Definition eval_hasvalue {A} (x : part A) (H : ter x) : hasvalue x (eval H) := proj2_sig (eval' H).

  (** [mkpart f]: turns a step-indexed function [f : nat -> option A] into a
      partial value. It searches with [mu_tot] for the least index [n] where
      [f n] is [Some _] and returns the carried value. At the index returned
      by the search, the [undef] branch is never taken. *)
  Definition mkpart {A} (f : nat -> option A) :=
    bind (mu_tot (fun n => match f n with Some a => true | None => false end))
      (fun n => match f n with Some a => ret a | None => undef end).

  (** Soundness of [mkpart]: every value of [mkpart f] is produced by [f] at
      some index. *)
  Lemma mkpart_hasvalue1 {A} (f : nat -> option A) :
    forall a, mkpart f =! a -> exists n, f n = Some a.
  Proof.
    move => a.
    rewrite /mkpart bind_hasvalue.
    (* [x] is the index returned by the search; [Hf] is the continuation's value. *)
    move => [] x [] / mu_hasvalue [] ter_mu Hmu Hf.
    exists x. destruct (f x). eapply (hasvalue_det (ret_hasvalue a0)) in Hf as ->.
    (* If [f x] were [None], the continuation is [undef], which has no value. *)
    reflexivity. eapply undef_hasvalue in Hf as [].
  Qed.

  (** If [f] is defined at some index, then [mkpart f] terminates. The proof
      uses the least index where [f] is defined, found with [mu_nat_dep]. *)
  Lemma mkpart_ter {A} (f : nat -> option A) n a :
    f n = Some a -> ter (mkpart f).
  Proof.
    move => Hn. unfold ter.
    rewrite /mkpart. setoid_rewrite bind_hasvalue.
    (* [f] is defined somewhere, and being defined is decidable. *)
    assert (Hf : exists n, f n <> None). exists n. firstorder congruence.
    assert (d : forall n : nat, {(fun n0 : nat => f n0 <> None) n} + {~ (fun n0 : nat => f n0 <> None) n}). { move => n0. destruct (f n0); firstorder congruence. }
    (* [m] is the index found by the search; [E] records it for [mu_nat_dep_min]. *)
    edestruct (mu_nat_dep _ d Hf) as [m Hm] eqn:E. eapply (f_equal (@proj1_sig _ _)) in E. cbn in E.
    destruct (f m) as [a0|]eqn:E2; try congruence.
    (* Witnesses: the value [a0] of [f m], reached through search index [m]. *)
    exists a0, m.
    rewrite mu_tot_hasvalue. repeat split.
    (* [m] satisfies the search predicate. *)
    + rewrite E2. reflexivity.
    (* Earlier indices do not, by minimality of [m]. *)
    + subst. move => m' Hle.
      destruct (f m') eqn:E3.
      * eapply mu_nat_dep_min in Hle. firstorder congruence.
      * reflexivity.
    (* The continuation at [m] returns [a0]. *)
    + rewrite E2. eapply ret_hasvalue.
  Qed.

  (** Completeness of [mkpart] for agnostic [f]: every value produced by [f]
      is the value of [mkpart f]. Agnosticism is needed because [mkpart]
      returns the value at the index found by its search, which may differ
      from [n]. *)
  Lemma mkpart_hasvalue2 {A} (f : nat -> option A) n a :
    agnostic f -> f n = Some a -> mkpart f =! a.
  Proof.
    move => Hagn Hn.
    (* [mkpart f] terminates with some [a'] that [f] produces at some [n']. *)
    destruct (mkpart_ter Hn) as [a' H].
    destruct (mkpart_hasvalue1 H) as [n' H'].
    now assert (a' = a) as -> by (eapply Hagn; eauto).
  Qed.

  Lemma mkpart_hasvalue {A} (f : nat -> option A) :
    agnostic f -> forall a, mkpart f =! a <-> exists n, f n = Some a.
  Proof.
    split.
    eapply mkpart_hasvalue1.
    move => [n Hn]. eapply mkpart_hasvalue2; eauto.
  Qed.

  (** [par x y]: runs [x] and [y] in parallel using [seval]. It searches with
      [mu_tot] for the least index [n] at which [seval x n] or [seval y n] is
      defined. It returns [inl] of the [x] result if [seval x n] is defined,
      and otherwise [inr] of the [y] result. *)
  Definition par : forall A B, part A -> part B -> part (A + B) :=
    fun A B x y =>
    bind (mu_tot (fun n => if seval x n is Some a then true else if seval y n is Some b then true else false))
      (fun n => if seval x n is Some a then ret (inl a) else if seval y n is Some b then ret (inr b) else undef).

  (** If [par x y] has value [inl a], then [x] has value [a]. *)
  Lemma par_hasvalue1 : forall A B (x : part A) (y : part B) a, hasvalue (par x y) (inl a) -> hasvalue x a.
  Proof.
    intros A B x y a [a' [(H1 & H2) % mu_tot_hasvalue H3]] % bind_hasvalue.
    (* Case split on both evaluators at the search index [a']. If both are
       [None], [H1] is contradictory. If [seval x a'] is defined, the result
       is [inl] of that value. Otherwise the result is [inr _], which
       contradicts [inl a]. *)
    destruct (seval x a') eqn:E1, (seval y a') eqn:E2; try congruence.
    - eapply ret_hasvalue_iff in H3 as [= ->]. eapply seval_hasvalue. eauto.
    - eapply ret_hasvalue_iff in H3 as [= ->]. eapply seval_hasvalue. eauto.
    - eapply ret_hasvalue_iff in H3 as [= ->].
  Qed.

  (** If [par x y] has value [inr b], then [y] has value [b]. *)
  Lemma par_hasvalue2 : forall A B (x : part A) (y : part B) b, hasvalue (par x y) (inr b) -> hasvalue y b.
  Proof.
    intros A B x y b [a' [(H1 & H2) % mu_tot_hasvalue H3]] % bind_hasvalue.
    (* Case split on both evaluators at the search index [a']. If both are
       [None], [H1] is contradictory. If [seval x a'] is defined, the result
       is [inl _], which contradicts [inr b]. Otherwise the result is [inr]
       of the [y] value. *)
    destruct (seval x a') eqn:E1, (seval y a') eqn:E2; try congruence.
    - eapply ret_hasvalue_iff in H3 as [= ->].
    - eapply ret_hasvalue_iff in H3 as [= ->].
    - eapply ret_hasvalue_iff in H3 as [= ->]. eapply seval_hasvalue. eauto.
  Qed.

  (** If either [x] or [y] terminates, then [par x y] terminates. *)
  Lemma par_hasvalue3 : forall A B (x : part A) (y : part B), ter x \/ ter y -> ter (par x y).
  Proof.
    intros A B x y [[a H] | [b H]].
    (* Case [ter x]: [seval x] is defined at some index [n]. *)
    - eapply seval_hasvalue in H as [n Hn].
      (* The search predicate holds at [n], so [mu_tot] yields some index [m]. *)
      destruct (@mu_tot_ter (fun n => if seval x n is Some a then true else if seval y n is Some b then true else false) n) as [m Hm].
      + now rewrite Hn.
      + pose proof (Hm' := Hm).
        eapply mu_tot_hasvalue in Hm as (H1 & H2).
        (* At [m] one of the evaluators is defined. The result is [inl] if
           [seval x m] is defined, and [inr] otherwise. *)
        destruct (seval x m) eqn:E1, (seval y m) eqn:E2; try congruence.
        * exists (inl a0). eapply bind_hasvalue. eexists. split. eapply Hm'. rewrite E1. eapply ret_hasvalue.
        * exists (inl a0). eapply bind_hasvalue. eexists. split. eapply Hm'. rewrite E1. eapply ret_hasvalue.
        * exists (inr b). eapply bind_hasvalue. eexists. split. eapply Hm'. rewrite E1 E2. eapply ret_hasvalue.
    (* Case [ter y]: [seval y] is defined at some index [n]. *)
    - eapply seval_hasvalue in H as [n Hn].
      destruct (@mu_tot_ter (fun n => if seval x n is Some a then true else if seval y n is Some b then true else false) n) as [m Hm].
      + rewrite Hn. clear. now destruct seval.
      + pose proof (Hm' := Hm).
        eapply mu_tot_hasvalue in Hm as (H1 & H2).
        destruct (seval x m) eqn:E1, (seval y m) eqn:E2; try congruence.
        * exists (inl a). eapply bind_hasvalue. eexists. split. eapply Hm'. rewrite E1. eapply ret_hasvalue.
        * exists (inl a). eapply bind_hasvalue. eexists. split. eapply Hm'. rewrite E1. eapply ret_hasvalue.
        * exists (inr b). eapply bind_hasvalue. eexists. split. eapply Hm'. rewrite E1 E2.
          (* The [y] value seen at [m] agrees with [b] from index [n]: since
             [seval y] is monotonic ([seval_mono]), it is agnostic. *)
          erewrite (monotonic_agnostic (@seval_mono _ _ _) Hn E2). eapply ret_hasvalue.
  Qed.

End assume_part.