From Undecidability.L Require Import L Tactics.Computable Tactics.ComputableTactics Tactics.Extract.
From MetaCoq Require Import Template.All TemplateMonad.Core Template.Ast.
Require Import String List.
Export String.StringSyntax.

Import MonadNotation.
Local Open Scope string_scope.

(* *** Generation of encoding functions *)

(* [mkLApp s L] left-nests applications: starting from the quoted head [s],
   every element of [L] is attached, in order, by wrapping the accumulated
   term and the element in constructor number 1 of the inductive [term_kn]
   (presumably the application constructor of L-terms; confirm against the
   definition of [term_kn]).  An empty [L] returns [s] unchanged. *)
Definition mkLApp (s : term) (L : list term) :=
  List.fold_left
    (fun acc arg => tApp (tConstruct (mkInd term_kn 0) 1 []) [acc; arg])
    L s.

(* [encode_arguments B a i A_j] produces the quoted application of the
   encoder for the (quoted) type [A_j] to the de Bruijn variable
   [tRel (a - i - 1)].

   - [B]   : not used in the body.
   - [a]   : at the call site in this file this is the number of constructor
             arguments.
   - [i]   : index of the current argument (supplied by [monad_map_i] at the
             call site in this file).
   - [A_j] : quoted type of the argument.

   NOTE(review): [tmTryInfer] is defined elsewhere; it presumably looks up
   (or falls back to obtaining under [name]) an instance of [registered A];
   confirm against its definition. *)
Definition encode_arguments (B : term) (a i : nat) A_j :=
      (* Turn the quoted type back into an actual [Type]. *)
      A <- tmUnquoteTyped Type A_j ;;
      (* Fresh identifier derived from the type's name with suffix "_term". *)
      name <- (tmEval cbv (name_of A_j ++ "_term") >>= Core.tmFreshName) ;;
      E <- tmTryInfer name None (registered A);;
      (* Quote the encoder [enc] specialised to [A] and the instance [E]. *)
      t <- ret (@enc A E);;
      l <- tmQuote t;;
      (* Apply the quoted encoder to the variable at index [a - i - 1]. *)
      ret (tApp l [tRel (a - i - 1) ]).

(* [mkMatch t1 t2 d cases] builds a quoted [tCase] node.

   - [t1]    : quoted type whose head symbol must be an inductive; it is also
               used as the binder type of the lambda passed as second
               argument of [tCase].
   - [t2]    : body of that lambda (callers in this file pass the return
               type here).
   - [d]     : passed as the third argument of [tCase] (the position of the
               matched term in this MetaCoq version; TODO confirm).
   - [cases] : called once per constructor with the constructor index and
               the processed list of its argument types; returns the body
               of the corresponding branch.

   Fails with "no head symbol found" when [split_head_symbol t1] returns
   [None]. *)
Definition mkMatch (t1 t2 d : Ast.term) (cases : nat -> list term -> Core.TemplateMonad term) :=
  hs_num <- tmGetOption (split_head_symbol t1) "no head symbol found";;
  let '(ind, Params) := hs_num in
  (* Number of parameters the head symbol is applied to. *)
  let params := List.length Params in
  L <- list_constructors ind >>= tmEval hnf ;;
    (* One branch per constructor: fetch its argument types, drop the first
       [params] of them, run [insert_params] on the rest, and let [cases]
       build the branch; each branch is paired with [args] as listed by
       [list_constructors]. *)
    body <- monad_map_i (fun i '(n, s, args) =>
                          l <- tmArgsOfConstructor ind i ;;
                          l' <- monad_map_i (insert_params FUEL Params) (skipn params l) ;;
                          t <- cases i l' ;; ret (args, t)) L ;;
  ret (tCase ((ind, params), Relevant) (tLambda naAnon t1 t2) d
             body).

(* Module path used below to refer to the constants "proc" and "redLe".
   The path components are listed innermost-first; this presumably denotes
   the file Undecidability.L.Util.L_facts (confirm against MetaCoq's
   dirpath convention). *)
Definition L_facts_mp := MPfile ["L_facts"; "Util"; "L"; "Undecidability"].

(* [tmMatchCorrect A] builds, as a [Prop], the statement of the "match
   correctness" lemma for the encoding of [A].

   Outline of the generated (quoted) statement, before unquoting:
   - one product over the hnf-evaluated, quoted [A];
   - then, iterated [num] times (with [num] the number of constructors),
     a product over [tTerm] followed by a product over ["proc"] applied
     to that term;
   - the conclusion is ["redLe"] applied to [mkNat num], [lhs] and [mtch],
     where [lhs] is the quoted encoder applied to [tRel (2 * num)] and then
     extended by [mkLApp] with the variables [tRel (2 * n + 1)] for [n]
     counting down from [num - 1] to [0], and [mtch] is the match built by
     [mkMatch] on [tRel (2 * num)].

   Fails with "no inductive" when the head of [A] is not recognised by
   [split_head_symbol], and with "failed" when no instance of
   [registered A] is found.

   NOTE(review): [x] (bound via [Core.tmFreshName "x"]) is not referenced
   afterwards in this definition. *)
Definition tmMatchCorrect (A : Type) : Core.TemplateMonad Prop :=
  t <- (tmEval hnf A >>= tmQuote) ;;
  hs_num <- tmGetOption (split_head_symbol t) "no inductive";;
  let '(ind, Params) := hs_num in
  num <- tmNumConstructors (inductive_mind ind) ;;
  x <- Core.tmFreshName "x" ;;
  mtch <- mkMatch t (* argument type *) tTerm (* return type *) (tRel (2 * num))
           (fun i (* ctr index *) ctr_types (* ctr type *) =>
              (* Number of arguments of constructor [i]. *)
              args <- tmEval cbv (|ctr_types|);;
              (* Quoted encodings of every constructor argument. *)
              C <- monad_map_i (encode_arguments t args) ctr_types ;;
              (* Branch body: lambdas over the constructor arguments around
                 the variable [tRel (args + 2 * (num - i) - 1)] applied
                 (via [mkAppList]) to the encoded arguments. *)
              ret (stack (map (tLambda (naAnon)) ctr_types)
                               (((fun s => mkAppList s C) (tRel (args + 2 * (num - i) - 1)))))
           ) ;;
   (* Look up the existing [registered A] instance; abort if there is none. *)
   E' <- Core.tmInferInstance None (registered A);;
   E <- tmGetMyOption E' "failed" ;;
   t' <- ret (@enc A E);;
   l <- tmQuote t';;
   (* Quoted encoder applied to the variable [tRel (2 * num)]. *)
   encn <- ret (tApp l [tRel (2*num) ]) ;;
   lhs <- ret (mkLApp encn ((fix f n := match n with 0 => [] | S n => tRel (2 * n + 1) :: f n end ) num)) ;;
   ter <- ret (tProd naAnon t (it (fun s : term => tProd naAnon tTerm (tProd naAnon (tApp (tConst (L_facts_mp, "proc") []) [tRel 0]) s)) num ((tApp (tConst (L_facts_mp, "redLe") []) [mkNat num; lhs; mtch]))));;
   (* Fully evaluate the quoted statement, then turn it into a [Prop]. *)
   ter <- tmEval cbv ter ;;
   tmUnquoteTyped Prop ter.

(* [matchlem n A] computes the statement produced by [tmMatchCorrect A] and
   hands it to [tmLemma] under the name [n]; the value returned by
   [tmLemma] is discarded and [tt] is returned instead. *)
Definition matchlem n A :=
  (Core.tmBind
     (tmMatchCorrect A)
     (fun statement => tmLemma n statement ;; ret tt)).

(* [tmGenEncode n A] runs the whole encoding pipeline for the type [A],
   using [n] as the base name.  In order, it:

   1. runs [tmEncode n A] (its result is not used directly; the constant
      named [n] in the current module is unquoted right afterwards as a
      value of type [encodable A]);
   2. states the lemma [n ++ "_proc"]: every encoding is a [proc];
   3. states the lemma [n ++ "_inj"]: the encoding function is injective;
   4. declares the instance ["registered_" ++ n] built with [mk_registered]
      from the encoding and the two lemmas;
   5. states the lemma [n ++ "_correct"], whose statement is computed by
      [tmMatchCorrect A].

   [tmMatchCorrect A] is run exactly once, in step 5.  It has to run after
   step 4 because it looks up an instance of [registered A].

   NOTE(review): the lemmas are presumably discharged by the file's
   obligation tactic; confirm against [tmLemma]'s behaviour. *)
Definition tmGenEncode (n : ident) (A : Type) : TemplateMonad unit :=
  e <- tmEncode n A;;
  modpath <- tmCurrentModPath tt ;;
  (* Re-bind [e] to the freshly defined constant [n], typed as [encodable A]. *)
  e <- tmUnquoteTyped (encodable A) (tConst (modpath, n) []);;
  p <- Core.tmLemma (n ++ "_proc") (forall x : A, proc (@enc_f A e x)) ;;
  n2 <- tmEval cbv ((n ++ "_inj"));;
  i <- Core.tmLemma n2 (injective (@enc_f _ e)) ;;
  n3 <- tmEval cbv ("registered_" ++ n) ;;
  d <- tmInstanceRed n3 None (@mk_registered A e p i);;
  n4 <- tmEval cbv (n ++ "_correct") ;;
  (* Generate the correctness statement and state it as lemma [n4]. *)
  (Core.tmBind (tmMatchCorrect A) (fun m => tmLemma n4 m ;; ret tt)).

Arguments tmGenEncode _%string _%type.

(*
Definition tmGenEncode' (n : ident) (A : Type) :=
  e <- tmEncode n A;;
  modpath <- tmCurrentModPath tt ;;
  e <- tmUnquoteTyped (encodable A) (tConst (modpath, n) );;
  p <- Core.tmLemma (n ++ "_proc") (forall x : A, proc (@enc_f A e x)) ;;
  n2 <- tmEval cbv ((n ++ "_inj"));;
  i <- Core.tmLemma n2  (injective (@enc_f _ e)) ;;
  n3 <- tmEval cbv ("registered_" ++ n) ;;
  d <- tmInstanceRed n3 None  (@mk_registered A e p i);;
  m <- tmMatchCorrect A ;; ret tt. *)


(* TODO: use another method instead, e.g. one based on typeclasses, as the default obligation tactic is very fragile. *)
(* Obligation tactic installed globally by this file.  It first tries to
   fold the goal into [injective enc_f], then dispatches on the goal shape:
   - [forall x : X, proc f]      -> [try register_proc]
   - [injective f]               -> [register_inj]
   - a goal containing [_ >(<= _) _] -> [extract match]
   If this dispatch fails, [Tactics.program_simpl] is used instead. *)
Global Obligation Tactic := try fold (injective (enc_f)); match goal with
                           | [ |- forall x : ?X, proc ?f ] => try register_proc
                           | [ |- injective ?f ] => register_inj
                           | [ |- context [_ >(<= _) _] ] => extract match
                           end || Tactics.program_simpl.