import Std.Data.TreeMap
import Std.Tactic.Do

/-!
This test is based on code by Bhavik Mehta.
It demonstrates that much of the boilerplate that he was forced to write for his proof
is now automated by the `mvcgen` tactic.
-/

-- Silence the warnings that `grind` and `mvcgen` print when used
-- (presumably their "experimental feature" notices — TODO confirm).
set_option grind.warning false
set_option mvcgen.warning false

section VendoredFromMathlib

/-- Local stand-in for Mathlib's `ℕ`, so the vendored code below can be used without Mathlib. -/
abbrev ℕ := Nat

/-- A monad transformer to generate random objects using the generic generator type `g`.
The generator is threaded as `StateT` state, wrapped in `ULift` (presumably to fit the universe
expected by the underlying monad — TODO confirm). -/
abbrev RandGT (g : Type) := StateT (ULift g)
/-- A monad to generate random objects using the generator type `g`. -/
abbrev RandG (g : Type) := RandGT g Id

/-- A monad transformer to generate random objects using the generator type `StdGen`.
`RandT m α` should be thought of as a random value in `m α`. -/
abbrev RandT := RandGT StdGen

/-- `Random m α` gives us machinery to generate values of type `α` in the monad `m`.

Note that `m` is a parameter as some types may only be sampleable with access to a certain monad. -/
class Random (m) (α : Type u) where
  /-- Sample an element of this type from the provided generator.
  The generator type `g` is chosen by the caller (any `g` with `RandomGen g`); the generator
  itself is threaded through the `RandGT` state. -/
  random [RandomGen g] : RandGT g m α

/-- `BoundedRandom m α` gives us machinery to generate values of type `α` between certain bounds in
the monad `m`. -/
class BoundedRandom (m) (α : Type u) [LE α] where
  /-- Sample a bounded element of this type from the provided generator.
  Requires a proof `h : lo ≤ hi`. The result is a subtype carrying proofs of `lo ≤ a` and
  `a ≤ hi`. -/
  randomR {g : Type} (lo hi : α) (h : lo ≤ hi) [RandomGen g] : RandGT g m {a // lo ≤ a ∧ a ≤ hi}

namespace Rand
/-- Generate a random `Nat`.

Draws from the generator held in the state, stores the advanced generator back, and returns the
drawn value. -/
def next [RandomGen g] [Monad m] : RandGT g m Nat := do
  let (value, advanced) := RandomGen.next (← get).down
  set (ULift.up advanced)
  return value

/-- Create a new random number generator distinct from the one stored in the state.

The state keeps the first half of `RandomGen.split`; the second half is handed to the caller. -/
def split {g : Type} [RandomGen g] [Monad m] : RandGT g m g := do
  let (kept, handedOut) := RandomGen.split (← get).down
  set (ULift.up kept)
  return handedOut

/-- Get the range of `Nat` that can be generated by the generator `g`.

The stored generator is only read, never replaced. -/
def range {g : Type} [RandomGen g] [Monad m] : RandGT g m (Nat × Nat) := do
  return RandomGen.range (← get).down
end Rand

namespace Random

open Rand

-- Every definition in this namespace works over an arbitrary monad `m`.
variable [Monad m]

/-- Generate a random value of type `α`.
Thin wrapper around `Random.random` with `α` passed explicitly. -/
def rand (α : Type u) [Random m α] [RandomGen g] : RandGT g m α := Random.random

/-- Generate a random value of type `α` between `lo` and `hi` inclusive.
Requires `h : lo ≤ hi`; the result is a subtype carrying both bounds. -/
def randBound (α : Type u)
    [LE α] [BoundedRandom m α] (lo hi : α) (h : lo ≤ hi) [RandomGen g] :
    RandGT g m {a // lo ≤ a ∧ a ≤ hi} :=
  -- The ascription pins `randomR`'s implicit generator type to this definition's `g`.
  (BoundedRandom.randomR lo hi h : RandGT g _ _)

/-- Generate a random `Fin n` (requires `NeZero n`).
Draws a `Nat` with `randNat` using bounds `0` and `n - 1`, converts it with `Fin.ofNat n` (which
reduces modulo `n`), and returns the updated generator as the new state. -/
def randFin {n : Nat} [NeZero n] [RandomGen g] : RandGT g m (Fin n) :=
  -- Note: the pattern `⟨g⟩` rebinds `g` to the generator *value*, shadowing the type `g`.
  fun ⟨g⟩ ↦ pure <| randNat g 0 (n - 1) |>.map (Fin.ofNat n) ULift.up

/-- `Fin n` is sampleable (for `NeZero n`) via `randFin`. -/
instance {n : Nat} [NeZero n] : Random m (Fin n) where
  random := randFin

/-- Bounded sampling of `Nat`: draw an offset `z : Fin (hi - lo + 1)` and return `lo + z`.
The trailing `_` in `randomR lo hi h _` binds the `RandomGen g` instance argument. -/
instance : BoundedRandom m Nat where
  randomR lo hi h _ := do
    let z ← rand (Fin (hi - lo + 1))
    -- `lo ≤ lo + z` is immediate; `lo + z ≤ hi` follows from `z < hi - lo + 1` and `lo ≤ hi`.
    pure ⟨
      lo + z.val, Nat.le_add_right _ _,
      Nat.add_le_of_le_sub' h (Nat.le_of_lt_add_one z.isLt)
    ⟩

end Random

end VendoredFromMathlib

open Random

/-- Take `k` samples, without replacement, from `[0..n-1]`.

Implemented as a partial Fisher–Yates shuffle of `0, …, n-1` whose array is stored sparsely.
The tree map `h` only records positions that have been overwritten, so `h.getD j ⟨j, _⟩` is the
current value at position `j`. At step `i`, an index `j ∈ [i, n-1]` is drawn. Output slot `x[i]`
receives the value at `j`, and position `j` receives the value at `i`. Position `i` itself is not
updated, because later steps only draw and read indices greater than `i`. -/
def sampler {m} [Monad m] (n k : ℕ) [NeZero n] (h : k ≤ n) : RandT m (Vector (Fin n) k) := do
  let mut x : Vector (Fin n) k := Vector.replicate _ 0
  -- NOTE(review): this `h` shadows the hypothesis `h : k ≤ n`. The `by grind` below presumably
  -- still finds that hypothesis as an inaccessible local. The name is left unchanged because
  -- `sampler_correct` may depend on the exact shape of the generated verification conditions.
  let mut h : Std.TreeMap ℕ (Fin n) := Std.TreeMap.empty
  for hi : i in [0:k] do
    -- `i ≤ n - 1` follows from `i < k` and `k ≤ n`.
    let j ← Subtype.val <$> randBound ℕ i (n - 1) (have : i < k := hi.upper; by grind)
    -- The bound `j ≤ n - 1` was dropped by `Subtype.val <$>`, so `j < n` is left as `sorry`.
    x := x.set i (h.getD j ⟨j, sorry⟩)
    -- `i < n` would follow from `i < k` and `k ≤ n`, but is also left as `sorry` here.
    h := h.insert j (h.getD i ⟨i, sorry⟩)
  return x

-- Ambient variables for the `Midway` development below. Several later declarations
-- (`randFin_total`, `randBound_spec`, `sampler_correct`) re-bind `m`, and some of `n`/`k`, locally.
variable {m : Type → Type u} [Monad m] [LawfulMonad m] {n k : ℕ}


/-- The state of `sampler`'s loop as an explicit value: the sparse swap map (`fst`) and the
output vector (`snd`). -/
abbrev Midway (n k : ℕ) : Type := MProd (Std.TreeMap ℕ (Fin n)) (Vector (Fin n) k)

/-- The loop state before the first iteration: no recorded swaps, and an output vector of `k`
zeros. -/
def init (n k : ℕ) [NeZero n] : Midway n k where
  fst := Std.TreeMap.empty
  snd := Vector.replicate k 0

-- Ambient `NeZero n` instance for the declarations below. For example, `valid_init` mentions
-- `init n k`, which requires it.
variable [NeZero n]

/-- One iteration of `sampler`'s loop, written as a pure function on `Midway` states.

Output slot `i` receives the current value at `j`, and position `j` is remapped to the current
value at `i`. The `Fin n` bounds for `i` and `j` cannot be derived from the arguments, hence the
`sorry`s. -/
def next (data : Midway n k) (i : ℕ) (hi : i < k) (j : ℕ) : Midway n k :=
  let swaps := data.fst
  let out := data.snd
  { fst := swaps.insert j (swaps.getD i ⟨i, sorry⟩),
    snd := out.set i (swaps.getD j ⟨j, sorry⟩) hi }

/-- Loop invariant for `sampler` after `i` iterations. Currently it only records that the first
`i` outputs are pairwise distinct. -/
structure Midway.valid (data : Midway n k) (i : ℕ) : Prop where
  /-- The first `i` entries of the output vector contain no duplicates. -/
  nodup_take : (data.2.toList.take i).Nodup
  -- Further invariants that a full proof of `Midway.valid_next` would need, currently commented
  -- out (NOTE(review): `Set.InjOn` is a Mathlib definition and is not available here):
  -- disjoint : ∀ j, i ≤ j → j ≤ n - 1 → data.1.getD j j ∉ data.2.toList.take i
  -- injOn : Set.InjOn (fun j ↦ data.1.getD j j) {j | i ≤ j ∧ j ≤ n - 1}

/-- The initial loop state satisfies the invariant with zero iterations done. -/
theorem valid_init : Midway.valid (init n k) 0 :=
  sorry -- domain-specific

/-- One loop step preserves the invariant, provided the drawn index satisfies `i ≤ j ≤ n - 1`.
These are the bounds guaranteed by `randBound ℕ i (n - 1)` in `sampler`.

NOTE(review): with only `nodup_take` in `Midway.valid`, this does not hold for arbitrary `data`.
For example, the swap map may already send `j` to a value present in the output prefix. The
commented-out `disjoint`/`injOn` fields would be needed to prove it. -/
theorem Midway.valid_next (data : Midway n k) (i : ℕ) (hi : i < k)
    (j : ℕ) (hij : i ≤ j) (hjn : j ≤ n - 1)
    (h : Midway.valid data i) : Midway.valid (next data i hi j) (i + 1) :=
  sorry -- domain-specific

open Std.Do

/-- Frame lemma for `randFin`: any assertion `P` that ignores the state holds after `randFin`
whenever it held before. It is registered with `@[spec]` so that `mvcgen` can presumably use it
when `randFin` appears in a program, e.g. after unfolding `randBound`. -/
@[spec]
theorem randFin_total {m : Type → Type u} [Monad m] [WPMonad m ps] {n : ℕ} [NeZero n] :
  ⦃fun _ => P⦄ -- it's unfortunate that we have to "guess" the frame `fun _ => P` ourselves. TODO: autogeneralize based on "parametricity" in `m`?
  randFin (n:=n) (m:=m) (g:=StdGen)
  ⦃⇓ _ _ => P⦄ := by
    unfold randFin
    -- introduce the precondition (`hs`) and the initial state (`s`)
    mintro hs ∀s
    simp [wp]

/-- `randBound ℕ lo hi h` preserves any state-independent assertion `P`.
The proof unfolds `randBound` down to `randFin`, where the `@[spec]` lemma `randFin_total` is
presumably what `mvcgen` applies. Registered with `@[spec]` for use in `sampler_correct`. -/
@[spec]
theorem randBound_spec {m : Type → Type u} [Monad m] [WPMonad m ps] (h : lo ≤ hi)  :
  ⦃fun _ => P⦄
  randBound (m:=m) (g:=StdGen) ℕ lo hi h
  ⦃⇓ _ _ => P⦄ := by
    mvcgen [randBound, BoundedRandom.randomR, rand, random]

/-- Correctness of `sampler`: the returned vector has no duplicate entries.
Proved with `mvcgen`, using `Midway.valid` as the loop invariant. The remaining `sorry`s (here
and in `valid_init`/`Midway.valid_next`) mark the domain-specific gaps. -/
theorem sampler_correct {m : Type → Type u} {k h} [Monad m] [WPMonad m ps] :
  ⦃⌜True⌝⦄
  sampler (m:=m) n k h
  ⦃⇓ xs => ⌜xs.toList.Nodup⌝⦄ := by
  -- Unfold `sampler` and split the goal into the obligations handled below (`inv1`, `vc1`–`vc4`).
  mvcgen -leave [sampler]
  -- Loop invariant: after visiting the prefix `xs.prefix` of `[0:k]`, the loop state is valid
  -- for that many iterations.
  case inv1 => exact (⇓ (xs, midway) => ⌜Midway.valid midway xs.prefix.length⌝)
  -- Preservation: one iteration. `cur` is the current index and `r` is the bounded sample from
  -- `randBound`, so `r.property : cur ≤ r.val ∧ r.val ≤ n - 1`.
  case vc1 pref cur _ _ _ _ _ _ r _ _ _ =>
    dsimp
    mframe
    -- the invariant for the prefix `pref`, used below as `Midway.valid _ pref.length`
    rename_i hinv
    mpure_intro
    simp only [List.length_append, List.length_cons, List.length_nil, Nat.zero_add]
    -- For `[0:k]` the current index equals the number of indices already visited.
    have : cur = pref.length := sorry -- by grind -- wishful thinking :(
    subst this
    apply Midway.valid_next _ pref.length _ r.val r.property.1 r.property.2 hinv
  -- Initialization: the invariant holds before the first iteration.
  case vc2 =>
    mpure_intro
    exact valid_init
  -- Postcondition: the invariant after all `k` iterations implies `Nodup` of the result.
  case vc3 =>
    dsimp
    mrename_i h
    mpure h
    mpure_intro
    have h := h.nodup_take
    simp at h
    -- prove List.take k r.snd.toList = r.snd.toList for r.snd : Vector (Fin n) k
    sorry
  -- Remaining side condition, discharged by `simp`.
  case vc4 => simp
