Require Import Arith Lia List.
Import ListNotations.

Require Import ssreflect ssrbool ssrfun.

(* Proofs without an explicit Proof-using clause depend only on the section
   variables that occur in the type of the statement. *)
Set Default Proof Using "Type".
(* A tactic runs only when exactly one goal is focused, so proofs with several
   goals have to be structured with bullets or braces. *)
Set Default Goal Selector "!".

(** [Enumerable X] packages an encoding of [X] into [nat] together with a
    decoding that undoes it. *)
Class Enumerable (X: Type) :=
  {
    (** Code of an element of [X]. *)
    to_nat : X -> nat;
    (** Decoding; every natural number is mapped to some element of [X]. *)
    of_nat : nat -> X;
    (** [of_nat] is a left inverse of [to_nat]. *)
    enumP {x: X} : of_nat (to_nat x) = x
  }.

(** Two elements have the same code exactly when they are equal, i.e. [to_nat]
    is injective. *)
Lemma to_nat_inj {X: Type} {enumX: Enumerable X} {x y: X}:
  to_nat x = to_nat y <-> x = y.
Proof.
  (* Right-to-left direction: substitute the equation [x = y]. *)
  constructor; last by move->.
  (* Left-to-right direction: apply [of_nat] to both sides, then cancel with
     [enumP]. *)
  move /(f_equal of_nat). by rewrite ?enumP.
Qed.


(** Transfers enumerability from [X] to [Y]: given [to_X : Y -> X] and
    [of_X : X -> Y] such that [of_X] undoes [to_X], the composition with the
    codec of [X] is a codec for [Y]. *)
Lemma enumarableI {X Y: Type} {enum: Enumerable X} (to_X: Y -> X) (of_X: X -> Y) :
  (forall (y: Y), of_X (to_X y) = y) -> Enumerable Y.
Proof.
  (* Encode by going through [X]; decode by coming back the same way. *)
  move=> cancel. exists (fun y => to_nat (to_X y)) (fun x => of_X (of_nat x)).
  move=> y. by rewrite enumP cancel.
Qed.


(** Natural numbers are enumerated by the identity in both directions. *)
Instance nat_Enumerable : Enumerable nat.
Proof. by exists id id. Qed.

(** Booleans: [true] is coded as [1] and [false] as [0]; decoding maps [0] to
    [false] and every other number to [true]. *)
Instance bool_Enumerable : Enumerable bool.
Proof.
  exists (fun b => if b then 1 else 0) (fun n => if n is 0 then false else true).
  (* Both boolean cases hold by computation. *)
  by case.
Qed.


(** Pairing of two natural numbers into one (Cantor-style, diagonal by
    diagonal) together with its inverse. *)
Module nat2_Enumerable.

(** [encode (x, y)] is [y] plus the sum [1 + 2 + ... + (y + x)]; the sum is
    computed by [nat_rec]. *)
Definition encode '(x, y) : nat :=
  y + (nat_rec _ 0 (fun i m => (S i) + m) (y + x)).

(** [decode n] iterates a successor-on-pairs step [n] times starting from
    [(0, 0)]: from [(S x, y)] it moves to [(x, S y)], and from [(0, y)] it
    moves to [(S y, 0)]. *)
Definition decode (n : nat) : nat * nat :=
  nat_rec _ (0, 0) (fun _ '(x, y) => if x is S x then (x, S y) else (S y, 0)) n.

(** [decode] is a left inverse of [encode]. *)
Lemma decode_encode {xy: nat * nat} : decode (encode xy) = xy.
Proof.
  (* Name the code [n] and proceed by induction on it, for all pairs. *)
  move Hn: (encode xy) => n. elim: n xy Hn.
  { by move=> [[|?] [|?]]. }
  move=> n IH [x [|y [H]]] /=.
  - (* Second component is 0: the predecessor code belongs to [(0, x')]
       when [x = S x']; the case [x = 0] contradicts a successor code. *)
    move: x => [|x [H]] /=; first done.
    by rewrite (IH (0, x)) /= -?H ?Nat.add_0_r.
  - (* Second component is [S y]: the predecessor code belongs to
       [(S x, y)]. *)
    by rewrite (IH (S x, y)) /= -?H ?Nat.add_succ_r.
Qed.


(** The code of a pair is at least the sum of its components. *)
Lemma encode_non_decreasing (x y: nat) : x + y <= encode (x, y).
Proof. elim: x => [| x IH] /=; [| rewrite Nat.add_succ_r /=]; by lia. Qed.

End nat2_Enumerable.

(** Pairs of natural numbers are enumerable via the pairing functions of the
    module [nat2_Enumerable]. *)
Instance nat2_Enumerable : Enumerable (nat * nat).
Proof.
  exists nat2_Enumerable.encode nat2_Enumerable.decode.
  move=> ?. by apply: nat2_Enumerable.decode_encode.
Qed.


(** Products of enumerable types are enumerable: encode both components, then
    encode the resulting pair of natural numbers. *)
Instance prod_Enumerable {X Y: Type} {enumX: Enumerable X} {enumY: Enumerable Y} : Enumerable (X * Y).
Proof.
  exists
    (fun '(x, y) => to_nat (to_nat x, to_nat y))
    (fun n => match of_nat n with | (nx, ny) => (of_nat nx, of_nat ny) end).
  (* Each of the three nested codecs cancels by [enumP]. *)
  move=> [x y]. by rewrite ?enumP.
Qed.


(** Sums of enumerable types are enumerable: an element is coded as a pair of
    a tag ([0] for [inl], [1] for [inr]) and the code of its payload. *)
Instance sum_Enumerable {X Y: Type} {enumX: Enumerable X} {enumY: Enumerable Y} : Enumerable (X + Y).
Proof.
  exists
    (fun t => match t with | inl x => to_nat (0, to_nat x) | inr y => to_nat (1, to_nat y) end)
    (* In the last branch no inner [n] is bound, so the whole code [n] is
       decoded as an element of [X]; this branch only makes the decoder total. *)
    (fun n => match of_nat n with | (0, n) => inl (of_nat n) | (1, n) => inr (of_nat n) | _ => inl (of_nat n) end).
  by case=> ?; rewrite ?enumP.
Qed.


(** Codec for lists over an enumerable type, built on the pairing functions of
    [nat2_Enumerable]. *)
Module list_Enumerable.
Section list_Enumerable_Section.
Variables (X: Type) (enumX: Enumerable X).

(** The empty list is coded as [1 + encode (0, 0)]; a list [x :: L] is coded
    as one plus the pair code of [(1 + to_nat x, encode L)]. The first
    component of the pair is thus [0] exactly for the empty list. *)
Fixpoint encode (L: list X) : nat :=
  if L is x :: L then 1+nat2_Enumerable.encode (1 + to_nat x, encode L) else 1+nat2_Enumerable.encode (0, 0).

(** [decode i n] decodes [n] with fuel [i]: each step unpairs [n - 1]; a first
    component [0] ends the list, a first component [S n1] yields the head
    [of_nat n1] and continues on the second component with one unit of fuel
    less. Without fuel the result is the empty list. *)
Fixpoint decode (i: nat) (n: nat) : list X :=
  if i is S i then
    match nat2_Enumerable.decode (n-1) with
    | (0, _) => []
    | (S n1, n2) => (of_nat n1) :: decode i n2
    end
  else [].

(* Presumably keeps [/=] from unfolding the pairing functions in the proof
   below, so that [nat2_Enumerable.decode_encode] can be used for rewriting. *)
Opaque nat2_Enumerable.encode nat2_Enumerable.decode.

(** Decoding the code of [L], using that same code as fuel, gives back [L]. *)
Lemma decode_encode {L: list X} : decode (encode L) (encode L) = L.
Proof.
  (* Generalize to any fuel [i] that is at least the code of [L]. *)
  suff: forall i, encode L <= i -> decode i (encode L) = L by (apply; lia).
  move=> i. elim: i L; first by (move=> [|? L] /=; lia).
  move=> i IH [|x L] /= ?; first done.
  rewrite Nat.sub_0_r nat2_Enumerable.decode_encode enumP IH; last done.
  (* The code of the tail is bounded by the pair code, hence by the fuel. *)
  have := nat2_Enumerable.encode_non_decreasing (S (@to_nat X enumX x)) (encode L).
  by lia.
Qed.


End list_Enumerable_Section.
End list_Enumerable.

(** Lists over an enumerable type are enumerable. The decoder passes the code
    itself as fuel to [list_Enumerable.decode]. *)
Instance list_Enumerable {X: Type} {enumX: Enumerable X} : Enumerable (list X).
Proof.
  exists (list_Enumerable.encode X enumX) (fun n => list_Enumerable.decode X enumX n n).
  move=> ?. by apply: list_Enumerable.decode_encode.
Qed.