From Undecidability.L Require Import Util.L_facts Prelim.StringBase.
From MetaCoq Require Import Template.All Checker.Checker.
Require Import Undecidability.Shared.Libs.PSL.Base.
Require Import String Ascii.

Open Scope string_scope.
Import MonadNotation.

(** Fuel bound shared by the fuel-driven metaprograms of this file: it is
    passed to [insert_params] in [mkFixMatch] and to [extract] in
    [tmExtract]. *)
Notation FUEL := 1000.

(** Helper for [string_of_int]: prepends the decimal digits of [n] to [acc],
    most significant digit first.  [fuel] only bounds the recursion; any value
    larger than the number of digits of [n] suffices. *)
Fixpoint string_of_int_aux (fuel n : nat) (acc : string) : string :=
  match fuel with
  | 0 => acc
  | S fuel' =>
    (* 48 is the ASCII code of the digit 0. *)
    let acc' := String (ascii_of_nat (48 + Nat.modulo n 10)) acc in
    match Nat.div n 10 with
    | 0 => acc'
    | q => string_of_int_aux fuel' q acc'
    end
  end.

(** Decimal representation of a natural number, for example [string_of_int 42]
    evaluates to the two-character string 42.  All naturals are supported,
    not only single digits, so [name_of] still yields a valid identifier
    fragment for constructor indices of 10 and above. *)
Definition string_of_int (n : nat) : string :=
  (* [S n] fuel always suffices: [n] has at most [S n] decimal digits. *)
  string_of_int_aux (S n) n "".

(** [it_i f n x] iterates [f] [n] times starting from [x] and passes each call
    its iteration index: the result is [f 0 (f 1 (... (f (n-1) x)))]. *)
Section it_i.
  Variables (X: Type) (f: nat -> X -> X).

  (* [it_i' i n x]: the outermost call of [f] receives index [i], and the
     index grows by one for each call further inside. *)
  Fixpoint it_i' (i : nat) (n : nat) (x : X) : X :=
    match n with
    | S remaining => f i (it_i' (S i) remaining x)
    | 0 => x
    end.

  Definition it_i := it_i' 0.
End it_i.

(** [stack [f1; ...; fk] x] applies the functions from right to left, giving
    [f1 (... (fk x))]. *)
Definition stack {X : Type} (l : list (X -> X)) (x : X) :=
  fold_right (fun g acc => g acc) x l.


(** First element of [l]; fails in the TemplateMonad when [l] is empty. *)
Definition hd {X : Type} (l : list X) : TemplateMonad X :=
  match l with
  | first :: _ => ret first
  | nil => tmFail "hd: empty list"
  end.

(** Quoted, head-normalized type of the quoted term [s]: [s] is unquoted, the
    type component is reduced with [hnf], and the result is quoted back. *)
Definition tmTypeOf (s : Ast.term) :=
  packed <- tmUnquote s ;;
  ty <- tmEval hnf (my_projT1 packed) ;;
  tmQuote ty.

(** [tmTryInfer n red A] looks for a type class instance of [A] ([red] is
    passed to [tmInferInstance]).  If none is found, it prints a notice and
    opens an obligation named [n] of type [A] with [tmLemma], so the user
    can provide the term interactively. *)
Definition tmTryInfer (n : ident) (red : option reductionStrategy) (A : Type) : TemplateMonad A :=
  r <- tmInferInstance red A ;;
    match r with
    | my_Some i => ret i
    | my_None =>
      (* NOTE(review): both branches below return [A] unchanged, so the type
         printed here is never reduced with [red]; only [tmInferInstance]
         above uses [red]. *)
      A' <- match red with Some red => ret A | None => ret A end;;
         tmPrint "Did not find an instance for ";;
         (tmPrint A');;
         (tmEval cbv ("open obligation " ++ n ++ " for it. You might want to register a instance before and rerun this.") >>= tmPrint);;
         tmLemma n A
    end.

(** Base name used to derive fresh identifiers for a quoted term:
    - named variables: prefix var_ followed by the variable name;
    - inductives: prefix type_ followed by the short name;
    - constants: their short name;
    - constructors: prefix cnstr_, the short name of the inductive and the
      constructor index;
    - anything else: no_name. *)
Definition name_of (t : Ast.term) : ident :=
  match t with
  | tVar x => "var_" ++ x
  | tInd (mkInd (_, base) _) _ => "type_" ++ name_after_dot base
  | tConst (_, base) _ => name_after_dot base
  | tConstruct (mkInd (_, base) _) idx _ => "cnstr_" ++ name_after_dot base ++ string_of_int idx
  | _ => "no_name"
  end.


(** Decide whether the quoted term [s] denotes a type.  Inductives are types.
    Constants and named variables are types when their (head-normalized)
    type is a sort.  An application is a type when its head is.  Everything
    else is not a type. *)
Fixpoint tmIsType (s : Ast.term) : TemplateMonad bool :=
  let has_sort_type (q : Ast.term) :=
      ty <- tmTypeOf q ;;
      match ty with tSort _ => ret true | _ => ret false end in
  match s with
  | tInd _ _ => ret true
  | tConst _ _ | tVar _ => has_sort_type s
  | tApp h _ => tmIsType h
  | _ => ret false
  end.

(** Number of constructors of the inductive named [n].  Fails when the
    declaration does not consist of exactly one body (mutual inductives). *)
Definition tmNumConstructors (n : kername) : TemplateMonad nat :=
  mib <- tmQuoteInductive n ;;
  match ind_bodies mib with
  | [ body ] => tmEval cbv (| ind_ctors body |)
  | _ => tmFail "Mutual inductive types are not supported"
  end.

(** Domains of the leading products of [B], outermost first.  They are
    returned as they appear, without adjusting de Bruijn indices. *)
Fixpoint argument_types (B : Ast.term) :=
  match B with
  | tProd _ dom cod => dom :: argument_types cod
  | _ => []
  end.

(** Split a quoted type into its inductive head and its arguments.  Returns
    [None] unless the term is an inductive, possibly applied. *)
Definition split_head_symbol A : option (inductive * list term) :=
  match A with
  | tInd ind _ => Some (ind, [])
  | tApp (tInd ind _) args => Some (ind, args)
  | _ => None
  end.

(** Constructor entries ([ind_ctors]) of the inductive [ind].  Fails when the
    declaration does not consist of exactly one body (mutual inductives). *)
Definition list_constructors (ind : inductive) : TemplateMonad (list (ident * term * nat)) :=
  mib <- tmQuoteInductive (inductive_mind ind) ;;
  match ind_bodies mib with
  | [ body ] => ret (ind_ctors body)
  | _ => tmFail "error: no mutual inductives supported"
  end.

(** Boolean test that two inductives have the same [inductive_mind] kernel
    name.  The [inductive_ind] component is ignored. *)
Definition eq_inductive (hs hs2 : inductive) :=
  if kername_eq_dec (inductive_mind hs) (inductive_mind hs2) then true else false.

(** Argument types of the [i]-th constructor of [ind], parameters included,
    read off the head-normalized type of the constructor. *)
Definition tmArgsOfConstructor ind i :=
  tmTypeOf (tConstruct ind i []) >>= (fun ctor_ty => ret (argument_types ctor_ty)).


(** [extracted a] holds an L-term associated with the Coq value [a] (as
    produced by the extraction functions of this file); [int_ext a]
    retrieves it from the instance. *)
Class extracted {A : Type} (a : A) := int_ext : L.term.
Arguments int_ext {_} _ {_}.
(* Let instance search see through [extracted], and simplify [my_projT1]
   occurrences in goals of the form [extracted _] during resolution. *)
Typeclasses Transparent extracted. Hint Extern 0 (extracted _) => progress (cbn [Common.my_projT1]): typeclass_instances.

(** [encodable A] provides a function encoding values of [A] as L-terms. *)
Class encodable (A : Type) := enc_f : A -> L.term.


(** Quotation of the inductive type [L.term] of L-terms. *)
MetaCoq Quote Definition tTerm := L.term.

(* NOTE(review): the module path of [L.term] is hard-coded here; it must be
   kept in sync with the location of the file defining [L.term]. *)
Definition term_mp := MPfile ["L"; "L"; "Undecidability"].
Definition term_kn := (term_mp, "term").

(** Quoted smart constructors for L-terms, built from constructor indices of
    [L.term]: index 2 for abstraction, 0 for variables, 1 for application. *)
Definition mkLam x := tApp (tConstruct (mkInd term_kn 0) 2 []) [x].
Definition mkVar x := tApp (tConstruct (mkInd term_kn 0) 0 []) [x].
Definition mkApp x y := tApp (tConstruct (mkInd term_kn 0) 1 []) [x; y].

(** [mkAppList s [b1; ...; bk]] builds the quoted, left-nested application
    [mkApp (... (mkApp s b1) ...) bk]. *)
Definition mkAppList s B := fold_left mkApp B s.

(** Quotations of the [nat] constructors, used by [mkNat]. *)
MetaCoq Quote Definition mkZero := 0.
MetaCoq Quote Definition mkSucc := S.

(** Quoted unary numeral: [mkNat n] is the MetaCoq term for [n], built from
    [mkZero] and repeated applications of [mkSucc]. *)
Fixpoint mkNat n :=
  match n with
  | S k => tApp mkSucc [mkNat k]
  | 0 => mkZero
  end.


(** [insert_params fuel Params i t] instantiates inductive parameters inside
    [t], which is used (in [mkFixMatch]) for the type of the [i]-th (0-based)
    constructor argument after the parameter binders were dropped.

    A de Bruijn index [tRel n] is replaced by the entry at position
    [params + i - n - 1] of [Params].  When [nth_error] finds no entry there,
    for example for references to earlier constructor arguments, the index
    is kept.  Only [tRel] and [tApp] are traversed; every other constructor
    (binders included) is returned unchanged.  Fails when [fuel] runs out.

    NOTE(review): natural-number subtraction truncates at 0, so an index
    [n >= params + i] would be mapped to the first parameter; presumably such
    indices do not occur in the closed constructor types obtained from
    [tmArgsOfConstructor] - confirm. *)
Fixpoint insert_params fuel Params i t :=
  let params := List.length Params in
  match fuel with 0 => tmFail "out of fuel in insert_params" | S fuel =>
  match t with
  | tRel n => (match nth_error Params (params + i - n - 1) with Some x => ret x | _ => ret (tRel n) end)
  | tApp s R => s <- insert_params fuel Params i s ;;
                 R <- monad_map (insert_params fuel Params i) R;;
                 ret (tApp s R)
  | _ => ret t
  end end.

(** Value carried by [o]; fails with message [err] when [o] is [None]. *)
Definition tmGetOption {X} (o : option X) (err : string) : TemplateMonad X :=
  match o with
  | None => tmFail err
  | Some x => ret x
  end.

(** Value carried by [o] (MetaCoq [option_instance]); fails with message [err]
    when [o] is [my_None]. *)
Definition tmGetMyOption {X} (o : option_instance X) (err : string) : TemplateMonad X :=
  match o with
  | my_None => tmFail err
  | my_Some x => ret x
  end.

(** [mkFixMatch f x t1 t2 cases] builds the quoted fixpoint
    [fix f (x : t1) : t2 := match x with ... end], where [t1] must be an
    inductive type, possibly applied to parameters.

    For the [i]-th constructor, [cases i l] produces the branch body.  Here
    [l] lists the argument types of the constructor with the leading
    parameter binders dropped and the parameters of [t1] substituted via
    [insert_params].  Each branch is paired with the [nat] component of the
    corresponding [ind_ctors] entry (presumably the constructor arity).
    Fails when [t1] has no inductive head. *)
Definition mkFixMatch (f x : ident) (t1 t2 : Ast.term) (cases : nat -> list term -> TemplateMonad term) :=
  hs_num <- tmGetOption (split_head_symbol t1) "no head symbol found";;
  let '(ind, Params) := hs_num in
  let params := List.length Params in
  L <- list_constructors ind >>= tmEval hnf ;;
    body <- monad_map_i (fun i '(n, s, args) =>
                          l <- tmArgsOfConstructor ind i ;;
                          l' <- monad_map_i (insert_params FUEL Params) (skipn params l) ;;
                          t <- cases i l' ;; ret (args, t)) L ;;
  (* The match scrutinee [tRel 0] is the bound variable [x]; the return
     predicate ignores it (non-dependent match returning [t2]). *)
  ret (Ast.tFix [BasicAst.mkdef
                   Ast.term
                   (nNamed f)
                   (tProd nAnon t1 t2)
                   (tLambda (nNamed x) t1 (tCase (ind, params)
                                                (tLambda nAnon t1 t2)
                                                (tRel 0)
                                                body)) 0] 0).

(* Registers the default checker configuration as an instance; presumably
   needed as the implicit checker flags of [eq_term] used below - TODO
   confirm. *)
Existing Instance config.default_checker_flags.

(** Quoted encoding of the [i]-th of [a] arguments in a constructor branch,
    where [A_j] is the argument type and [B] the type being encoded.

    This assumes the de Bruijn layout built by [tmEncode] and [mkFixMatch]:
    inside the branch, the argument is [tRel (a - i - 1)], the matched
    variable is [tRel a], and the recursive encoder is [tRel (S a)].
    - If [A_j] is [eq_term]-equal to [B], the recursive encoder is applied to
      the argument.
    - Otherwise an [encodable] instance for [A_j] is obtained with
      [tmTryInfer], which opens an obligation if none exists.  Its encoding
      function is head-normalized, quoted, and applied to the argument. *)
Definition encode_arguments (B : term) (a i : nat) A_j :=
    if eq_term uGraph.init_graph B A_j
    then
      ret (tApp (tRel (S a)) [tRel (a - i -1)])
    else
      A <- tmUnquoteTyped Type A_j ;;
      name <- (tmEval cbv (name_of A_j ++ "_term") >>= tmFreshName) ;;
      E <- tmTryInfer name None (encodable A);;
      t <- tmEval hnf (@enc_f A E);;
      l <- tmQuote t;;
      ret (tApp l [tRel (a - i - 1) ]).

(** Define [name] as [x] (reduced with [red]), register the resulting constant
    as a type class instance, and return the defined value. *)
Definition tmInstanceRed name red {X} (x:X) :=
  defined <- tmDefinitionRed name red x;;
  quoted <- tmQuote defined;;
  match quoted with
  | tConst kn _ => tmExistingInstance (ConstRef kn);; tmReturn defined
  | _ => tmFail "internal invariant violated : tmInstanceRed"
  end.

(** [tmEncode name A] derives an [encodable A] instance for the inductive
    type [A] and registers it under [name].

    The encoder is a fixpoint over [A] built with [mkFixMatch].  Its branch
    for the [i]-th of [num] constructors binds the constructor arguments and
    returns the quoted L-term [mkLam] applied [num] times around
    [mkVar (num - i - 1)] applied to the encoded arguments (see
    [encode_arguments]).  Returns the head-normalized encoder.  Fails when
    the head-normal form of [A] is not an inductive, possibly applied. *)
Definition tmEncode (name : string) (A : Type) :=
  t <- (tmEval hnf A >>= tmQuote) ;;
  hs_num <- tmGetOption (split_head_symbol t) "no inductive";;
  let '(ind, Params) := hs_num in
  num <- tmNumConstructors (inductive_mind ind) ;;
  f <- tmFreshName "encode" ;;
  x <- tmFreshName "x" ;;
  ter <- mkFixMatch f x t tTerm
           (fun i ctr_types =>
              (* [args]: number of constructor arguments bound in this branch. *)
              args <- tmEval cbv (|ctr_types|);;
              C <- monad_map_i (encode_arguments t args) ctr_types ;;
              ret (stack (map (tLambda (nAnon)) ctr_types)
                               (it mkLam num ((fun s => mkAppList s C) (mkVar (mkNat (num - i - 1))))))
           ) ;;
  u <- tmUnquoteTyped (encodable A) ter;;
 tmInstanceRed name None u;;
 tmEval hnf u.














(** L-term for the [i]-th of [num] constructors of an inductive when that
    constructor takes [args] arguments:
    [lam^args (lam^num (var (num - i - 1) #(num + args - 1) ... #num))],
    which is the Scott encoding of the constructor. *)
Definition gen_constructor args num i :=
  it lam args
     (it lam num
         (it_i (fun k acc => L.app acc #(k + num)) args (var (num - i - 1)))).

(** Extraction of the [i]-th constructor [a] of the inductive [n], given the
    quoted type [t] of [a].  The result is [gen_constructor] evaluated with
    [cbv] for the number of leading products of [t] and the number of
    constructors of [n].  When [def'] is [Some def], the result is also
    registered as an instance under a fresh name derived from [def]. *)
Definition extract_constr {A} (a : A) n (i : nat) (t : Ast.term) def' :=
  num <- tmNumConstructors n ;;
      r <- tmEval cbv (gen_constructor (|argument_types t|) num i : extracted a) ;;
      match def' with
      | Some def => def2 <- tmFreshName def ;;
                     tmInstanceRed def2 None r;;tmReturn tt
      | None => tmReturn tt
      end;;
      ret r.

(** Extract a constructor [a], possibly applied to some arguments.  [a] is
    normalized with [cbv] and quoted, its type [A] is normalized with [hnf]
    and quoted, and both are passed to [extract_constr].  Fails unless the
    normalized [a] is a constructor, possibly applied. *)
Definition tmExtractConstr' (def : option ident) {A : Type} (a : A) :=
  qval <- (tmEval cbv a >>= tmQuote) ;;
  qty <- (tmEval hnf A >>= tmQuote) ;;
  match qval with
  | Ast.tApp (Ast.tConstruct (mkInd n _) i _) _
  | Ast.tConstruct (mkInd n _) i _ => extract_constr a n i qty def
  | _ => tmFail "this is not a constructor"
  end.

(** Same as [tmExtractConstr'], but always registers the result as an instance
    under a fresh name derived from [def]. *)
Definition tmExtractConstr (def : ident) {A : Type} (a : A) :=
  @tmExtractConstr' (Some def) A a.










(** [↑ env] lifts a de Bruijn renaming under one binder: index 0 maps to 0
    and [S n] maps to [S (env n)]. *)
Notation "↑ env" := (fun n => match n with 0 => 0 | S n => S (env n) end) (at level 10).

(** [inferHead' s revArg R] searches for an [extracted] instance for the head
    [s] applied to [rev revArg]; [revArg] holds the already-applied
    arguments in reverse order.

    On success it returns the instance together with the arguments [R] that
    are still to be applied.  Otherwise, if the first element of [R] is
    closed, that element is moved into the applied prefix and the search is
    retried.  If not, the search fails after printing diagnostics. *)
Fixpoint inferHead' (s:Ast.term) (revArg R: list Ast.term) : TemplateMonad (L.term * list Ast.term) :=
  (* [forallb (fun _ => false) revArg] is true exactly when [revArg] is empty. *)
  s'0 <- tmEval cbn (if forallb (fun _ => false) revArg then s else Ast.tApp s (rev revArg));;
  s' <- tmUnquote s'0;;
  s'' <- tmEval cbn (my_projT2 s');;
  res <- tmInferInstance None (extracted (A:=my_projT1 s') s'');;
  match res with
    my_Some s'' => ret (s'',R)
  | my_None =>
    (* Only closed arguments may be absorbed into the head. *)
    let doSplit := match R with
                    | [] => false
                    | r :: R => if closedn 0 r then true else false
                    end
    in
    match doSplit,R with
      true,r::R => inferHead' s (r::revArg) R
    | _,_ => let lhs := string_of_term s'0 in
             let rhs := string_of_list string_of_term R in
             tmMsg "More readable: initial segment:";;tmPrint s'';;tmMsg "With remainder:";;tmPrint R;;
             tmFail ("Could not extract in inferHead (moreReadable version in *coq*): could not infer any instance for initial segment of " ++lhs ++ " with further arguments "++ rhs)
    end
  end.

(** Resolve the head of an application.  For a constant or constructor head,
    [inferHead'] finds an [extracted] instance for the head applied to a
    prefix of [R]; the instance is returned with [inl], together with the
    remaining arguments.  Any other head is returned unchanged with [inr],
    together with all of [R]. *)
Definition inferHead (s:Ast.term) (R:list Ast.term) : TemplateMonad ((L.term + Ast.term) * list Ast.term) :=
  match s with
  | Ast.tConst _ _
  | Ast.tConstruct _ _ _ =>
    found <- inferHead' s [] R ;;
    let '(head, rest) := found in
    ret (inl head, rest)
  | _ => ret (inr s, R)
  end.

(** [extract env s fuel] translates the quoted Coq term [s] into an L-term.

    [env] maps de Bruijn indices of [s] to de Bruijn indices of the L-term
    being built.
    - It is lifted with [↑] whenever a Coq binder becomes an L abstraction:
      [tLambda], [tLetIn], and the constructor-argument binders of match
      branches.
    - It is shifted with [S] when the translation introduces an L abstraction
      with no Coq counterpart.  This happens in the fixpoint case
      (presumably because [rho] adds one binder) and for the extra innermost
      binder of every match branch, which is discharged by applying the whole
      match to [I].

    Constants, constructors, named variables and heads of applications are
    resolved through [extracted] instances.  When no instance exists for a
    constant, constructor or named variable, an obligation is opened via
    [tmTryInfer].  Fails when [fuel] runs out and on unsupported term
    constructors. *)
Fixpoint extract (env : nat -> nat) (s : Ast.term) (fuel : nat) : TemplateMonad L.term :=
  match fuel with 0 => tmFail "out of fuel" | S fuel =>
  match s with
    Ast.tRel n =>
    t <- tmEval cbv (var (env n));;
    ret t
  | Ast.tLambda _ _ s =>
    (* The Coq binder becomes an L binder, so the environment is lifted. *)
    t <- extract (↑ env) s fuel ;;
    ret (lam t)
  | Ast.tFix [BasicAst.mkdef nm ty s _] _ =>
    (* The fixpoint variable is bound by the tLambda built here; the extra
       shift accounts for the binder added by [rho]. *)
    t <- extract (fun n => S (env n)) (Ast.tLambda nm ty s) fuel ;;
    ret (rho t)
  | Ast.tApp s R =>
    res <- inferHead s R;;
    let '(res,R') := res in
    (* Either the head (with a prefix of R) was resolved to an instance, or
       it still has to be translated. *)
    t <- (match res with
            inl s' => ret s'
          | inr s => extract env s fuel
          end);;
    monad_fold_left (fun t1 s2 => t2 <- extract env s2 fuel ;; ret (L.app t1 t2)) R' t
  | Ast.tConst n _ =>
    a <- tmUnquote s ;;
    a' <- tmEval cbn (my_projT2 a);;
    n <- (tmEval cbv (String.append (name_of s) "_term") >>= tmFreshName) ;;
    i <- tmTryInfer n (Some cbn) (extracted a') ;;
    ret (@int_ext _ _ i)
  | Ast.tConstruct (mkInd n _) _ _ =>
    a <- tmUnquote s ;;
    a' <- tmEval cbn (my_projT2 a);;
    nm <- (tmEval cbv (String.append (name_of s) "_term") >>= tmFreshName) ;;
    i <- tmTryInfer nm (Some cbn) (extracted a') ;;
    ret (@int_ext _ _ i)
  | Ast.tCase _ _ s cases =>
    (* Translate a branch that binds [k] constructor arguments.  Each Coq
       binder becomes an L binder (lifted environment), and one extra
       innermost L binder is added.  That binder is discharged by applying
       the whole match to [I] below. *)
    let fix extractCaseEtaExpand (env : nat -> nat) (s:Ast.term) (k:nat) {struct k}: TemplateMonad L.term :=
        match k,s with
          0,_ =>
          t <- extract (fun i => S (env i)) s fuel;;
          ret (lam t)
        | S k,tLambda _ _ s =>
          t <- extractCaseEtaExpand (↑ env) s k ;;
          ret (lam t)
        | S _, _ => tmFail "Match case does not contain the expected abstractions for bound argument."
        end
    in
    t <- extract env s fuel ;;
    M <- monad_fold_left (fun t1 '(n,s2) => t2 <- extractCaseEtaExpand env s2 n;; ret (L.app t1 t2)) cases t ;;
    ret (L.app M I)
  | Ast.tLetIn _ s1 _ s2 =>
    (* A local definition becomes a beta-redex: (lam t2) t1, where the
       abstraction binds the let-variable, so the environment is lifted. *)
    t1 <- extract env s1 fuel ;;
    t2 <- extract (↑ env) s2 fuel ;;
    ret( L.app (lam t2) t1)
  | Ast.tFix _ _ => tmFail "Mutual Fixpoints are not supported"
  | tVar _ => a <- tmUnquote s ;;
    a' <- tmEval cbn (my_projT2 a);;
    n <- (tmEval cbv (String.append (name_of s) "_term") >>= tmFreshName) ;;
    i <- tmTryInfer n (Some cbn) (extracted a') ;;
    ret (@int_ext _ _ i)
  | tEvar _ _ => tmFail "tEvar is not supported"
  | tSort _ => tmFail "tSort is not supported"
  | tCast _ _ _ => tmFail "tCast is not supported"
  | tProd _ _ _ => tmFail "tProd is not supported"
  | tInd a _ => tmPrint a;;tmFail "tInd is not supported (probably there is a type not in prenex-normal form)"
  | tProj _ _ => tmFail "tProj is not supported"
  | tCoFix _ _ => tmFail "tCoFix is not supported"
  end end.

(** Kernel name of the constant at the head of a term, possibly applied.
    Returns [None] when the head is not a constant. *)
Fixpoint head_of_const (t : term) :=
  match t with
  | tApp fn _ => head_of_const fn
  | tConst kn _ => Some kn
  | _ => None
  end.

(** Quotation of [a] after unfolding its head constant with [tmEval (unfold h)].
    When [a] has no head constant, its plain quotation is returned. *)
Definition tmUnfoldTerm {A}(a:A) :=
  quoted <- tmQuote a;;
  match head_of_const quoted with
  | None => ret quoted
  | Some h =>
    unfolded <- tmEval (unfold h) a;;
    tmQuote unfolded
  end.

(** [tmExtract nm a] extracts an L-term for [a].

    The head constant of [a] is unfolded via [tmUnfoldTerm], and the result
    is translated with [extract], starting from the identity environment and
    [FUEL].  When [nm] is [Some name], the term is also registered as an
    [extracted a] instance under a fresh name derived from [name]. *)
Polymorphic Definition tmExtract (nm : option string) {A} (a : A) : TemplateMonad (extracted a) :=
  q <- tmUnfoldTerm a ;;
  t <- extract (fun x => x) q FUEL ;;
  match nm with
    Some nm => nm <- tmFreshName nm ;;
              @tmInstanceRed nm None (extracted a) t ;;
              ret t
  | None => ret t
  end.

(* Mark [extracted] as opaque so that tactics do not unfold it. *)
Opaque extracted.













(* Obligations (such as those opened by tmLemma in tmTryInfer) are not solved
   automatically: the obligation tactic is idtac. *)
Global Obligation Tactic := idtac.

(* Let type class resolution unfold [encodable]. *)
Typeclasses Transparent encodable.