{-# LANGUAGE DeriveGeneric #-}

module GardGround.Syntax.Abstract (
  NeutralRoot(..),
  Neutral(..),
  ValueError(..),
  Value(..),
  Expr(..),
  vapp,
  quote,
  unify,
  eval,
  typeInf,
  typeChk,
) where

import Control.Monad (foldM)
import Generic.Data (Generic)
import Numeric.Natural (Natural)

import GardGround.Syntax.Literal (Literal(..), PrimTy(..), lookupBinder, litTypeOf)
import GardGround.Utils.SteParser.Lex (Ident)

-- | The root (head) of a neutral value, and the leaf form of an expression.
--
--   * 'NrLiteral' — a literal; this includes primitive types, e.g. the type
--     of types is @'NrLiteral' ('LPrimTy' 'PtType')@.
--   * 'Global' — a named identifier; its type cannot be inferred in this
--     module ('typeInf' reports 'VEInferGlobal').
--   * 'Local' — a variable looked up in the environment via 'lookupBinder'.
--   * 'Quote' — a fresh variable introduced by 'quote' / 'unify' when going
--     under a binder; the number is the binder depth at which it was made.
data NeutralRoot = NrLiteral Literal
  | Global Ident
  | Local Natural
  | Quote Natural
  -- note that Local and Quote are numbered in opposite directions
  deriving (Eq, Generic)

-- | Debug rendering: a literal shows as itself; every other root shows as
--   its constructor name, a colon, and the 'show' of its payload.
instance Show NeutralRoot where
    show root = case root of
        NrLiteral lit -> show lit
        Global ident -> "Global:" ++ show ident
        Local idx -> "Local:" ++ show idx
        Quote idx -> "Quote:" ++ show idx

-- | Neutral values: a root with zero or more argument values applied to it.
--   The argument list (the spine) is stored newest-first: 'vapp' conses each
--   new argument onto the front, so the last element was applied first.
data Neutral = Neutral NeutralRoot [Value]

-- | Errors raised while evaluating, quoting, unifying or type checking.
--
--   * 'VEInvalidApply' — a value that cannot be applied (a literal-rooted
--     neutral, see 'vapp'), together with the argument it was applied to.
--   * 'VEInvalidLamTy' — the function position of an application was
--     inferred to have this non-'VPi' type.
--   * 'VEOutOfRangeQuote' — a 'Quote' variable that is not bound at the
--     current depth; the exact number stored depends on the raising site.
--   * 'VEOutOfRangeLocal' — a 'Local' variable missing from the environment;
--     carries the variable number minus the environment length.
--   * 'VEInferGlobal' — the type of a 'Global' cannot be inferred here.
--   * 'VEInferLam' — a bare lambda (its body is carried) whose type cannot
--     be inferred.
--   * 'VEUnifyFailed' — two values that failed to unify.
data ValueError =
    VEInvalidApply Value Value
  | VEInvalidLamTy Value
  | VEOutOfRangeQuote Natural
  | VEOutOfRangeLocal Natural
  | VEInferGlobal Ident
  | VEInferLam Expr
  | VEUnifyFailed Value Value

-- | Semantic values; binders are represented by Haskell functions.
--
--   * 'VLam' — a lambda; applying its closure may fail with a 'ValueError'.
--   * 'VPi' — a dependent function type: the domain, and the codomain as a
--     closure over the argument value.
--   * 'VNeutral' — a stuck term: a root with applied arguments ('Neutral').
data Value = VLam (Value -> Either ValueError Value)
  | VPi Value (Value -> Either ValueError Value)
  | VNeutral Neutral

-- The basic design is based on LambdaPi (https://www.andres-loeh.de/LambdaPi/LambdaPi.pdf),
-- with more advanced techniques taken from smalltt (https://github.com/AndrasKovacs/smalltt).

-- | Result of an evaluation step: either the error that stopped it, or a value.
type EvalResult = Either ValueError Value

-- | Apply a value to an argument value.
--
--   * 'VLam' closures are invoked on the argument.
--   * 'VPi' values are applied through their codomain closure (kept for
--     convenience).
--   * A neutral value whose root is a literal cannot be applied and yields
--     'VEInvalidApply' carrying the value and the argument.
--   * Any other neutral value gets the argument pushed onto the front of
--     its (newest-first) spine.
vapp :: Value -> Value -> EvalResult
vapp (VLam f) = f
vapp (VPi _ f) = f -- this is mostly for convenience
vapp lit@(VNeutral (Neutral (NrLiteral _) _)) = Left . VEInvalidApply lit
vapp (VNeutral (Neutral nam xs)) = \v -> Right . VNeutral . Neutral nam $ v:xs

-- | A free variable: the given root as a neutral value with an empty spine.
vfree :: NeutralRoot -> Value
vfree root = VNeutral (Neutral root [])

-- | Expressions; terms and types share a single syntax.
--
--   * 'ENeutralRoot' — a leaf: a literal, global, local or quote variable.
--   * 'EApply' — application of the first expression to the second.
--   * 'ELam' — a lambda; 'eval' evaluates its body with the argument pushed
--     onto the front of the environment.
--   * 'EPi' — a dependent function type: the domain, then the codomain,
--     which is evaluated with the argument pushed onto the environment.
--   * 'EAnnot' — an expression annotated with a type expression.
data Expr =
    ENeutralRoot NeutralRoot
  | EApply Expr Expr
  | ELam Expr
  | EPi Expr Expr
  | EAnnot Expr Expr

-- | Result of quoting a value back to syntax: an error, or an expression.
type QuoteResult = Either ValueError Expr

-- | Quote a neutral value back into an expression at binder depth @i@.
--
--   A @'Quote' j@ root (made by 'quote' at depth @j@) becomes
--   @'Local' (i - j - 1)@; a 'Quote' whose depth is not below @i@ is
--   reported as 'VEOutOfRangeQuote' with @j + 1 - i@. Other roots are kept
--   unchanged. The spine is stored newest-first, so it is reversed to
--   rebuild the applications from the innermost one outwards.
neutralQuote :: Natural -> Neutral -> QuoteResult
neutralQuote i (Neutral nam xs) = do
  root <- quoteRoot nam
  foldM applyArg (ENeutralRoot root) (reverse xs)
  where
    -- checked first, because Natural subtraction must not go negative
    quoteRoot (Quote j)
      | i < j + 1 = Left . VEOutOfRangeQuote $ (j + 1) - i
      | otherwise = Right . Local $ i - (j + 1)
    quoteRoot other = Right other
    applyArg acc arg = EApply acc <$> quote i arg

-- quote0 = quote 0

-- | Read a value back into an expression, at binder depth @i@.
--
--   Closures (a 'VLam' body and a 'VPi' codomain) are opened by applying
--   them to a fresh @'Quote' i@ variable and quoting the result at depth
--   @i + 1@; 'neutralQuote' then turns those 'Quote' roots into 'Local's.
quote :: Natural -> Value -> QuoteResult
quote i val =
  case val of
    VLam body -> ELam <$> underBinder body
    VNeutral neut -> neutralQuote i neut
    VPi dom cod -> EPi <$> quote i dom <*> underBinder cod
  where
    underBinder clo = clo (vfree (Quote i)) >>= quote (i + 1)

-- | Evaluate an expression in an environment of values, newest binder first
--   (as pushed by 'ELam' and 'EPi').
--
--   'Local' variables are resolved with 'lookupBinder'; a miss is reported
--   as 'VEOutOfRangeLocal'. Every other root becomes a bare neutral.
--   Annotations are erased, applications go through 'vapp', and binders
--   become closures that push their argument onto the environment.
eval :: Expr -> [Value] -> EvalResult
eval (ENeutralRoot (Local x)) env =
  maybe (Left . VEOutOfRangeLocal $ x - toEnum (length env)) Right (lookupBinder env x)
eval (ENeutralRoot root) _ = Right (vfree root)
eval (EAnnot e _) env = eval e env
eval (EApply fun arg) env = do
  funVal <- eval fun env
  argVal <- eval arg env
  vapp funVal argVal
eval (ELam body) env = Right (VLam (\y -> eval body (y : env)))
eval (EPi dom body) env = do
  domVal <- eval dom env
  pure (VPi domVal (\y -> eval body (y : env)))

-- | Check two values for equality, looking through their closures.
--
--   @i@ is the next unused 'Quote' depth. Binders ('VLam', and 'VPi'
--   codomains) are compared by applying both closures to the same fresh
--   @'Quote' i@ variable and unifying the results at depth @i + 1@.
--   Neutral values unify when their roots are equal, their spines have the
--   same length, and the spines unify pairwise.
--
--   Returns @'Right' ()@ on success. Otherwise it returns the first error
--   encountered: 'VEUnifyFailed', or an error raised by a closure.
unify :: Natural -> Value -> Value -> Either ValueError ()
unify i l r =
  case (l, r) of
    (VLam f, VLam f') -> underBinder f f'
    (VPi a f, VPi a' f') -> do
      unify i a a'
      underBinder f f'
    (VNeutral (Neutral nr xs), VNeutral (Neutral nr' xs'))
      -- Compare the cheap parts first. Differing roots can never unify.
      -- Spines are stored newest-first, so spines of different length would
      -- otherwise be compared misaligned, giving misleading inner errors.
      | nr /= nr' || length xs /= length xs' -> Left $ VEUnifyFailed l r
      | otherwise -> sequence_ (zipWith (unify i) xs xs')
    _ -> Left $ VEUnifyFailed l r

  where
    -- instantiate both closures with the same fresh variable and compare
    underBinder :: (Value -> EvalResult) -> (Value -> EvalResult) -> Either ValueError ()
    underBinder f f' = do
      let vf = vfree $ Quote i
      v  <- f  vf
      v' <- f' vf
      unify (i + 1) v v'

-- | Evaluate an expression in a typing context, using the value half of
--   each (value, type) pair as the evaluation environment.
evalTenv :: [(Value, Value)] -> Expr -> EvalResult
evalTenv tenv e = eval e (map fst tenv)
{-# INLINE evalTenv #-}

-- | Number of entries in a typing context, as a 'Natural'.
tenvLen :: [(Value, Value)] -> Natural
tenvLen = fromIntegral . length
{-# INLINE tenvLen #-}

-- | Try to infer the type of an expression in a typing context.
--
--   The context @tenv@ is a stack of (value, type) pairs, newest first. The
--   value halves serve as the evaluation environment (see 'evalTenv').
--   Returns the inferred type as a value, or the first error encountered.
--
--   As in LambdaPi, annotation types are first checked to be types, and
--   application arguments are /checked/ against the function's domain with
--   'typeChk' rather than inferred. An unannotated lambda can therefore be
--   passed as an argument, but its type cannot be inferred on its own
--   ('VEInferLam').
--
--   NOTE(review): new binders are represented by @'Local' (tenvLen tenv)@
--   (numbered by context length), while 'neutralQuote' numbers 'Local'
--   relative to the quoting depth. Confirm 'lookupBinder' resolves both to
--   the intended context entry.
typeInf :: [(Value, Value)] -> Expr -> EvalResult
typeInf tenv kind = case kind of
  ENeutralRoot (Local x) -> case lookupBinder tenv x of
    Nothing -> Left . VEOutOfRangeLocal $ x - (tenvLen tenv)
    Just (_, y) -> Right y
  ENeutralRoot (NrLiteral lit) -> Right . VNeutral $ Neutral (NrLiteral . LPrimTy $ litTypeOf lit) mempty
  ENeutralRoot (Global g) -> Left $ VEInferGlobal g
  ENeutralRoot (Quote q) -> Left $ VEOutOfRangeQuote q

  EAnnot e e' -> do
    -- the annotation must itself be a type before it is used as one
    typeChk tenv e' vstar
    v' <- evalTenv tenv e'
    typeChk tenv e v'
    Right v'

  EApply (ELam e) e' -> do
    t' <- typeInf tenv e'
    v' <- evalTenv tenv e'
    -- here, we don't have to unify the argument
    typeInf ((v', t'):tenv) e

  EApply e e' -> do
    t  <- typeInf tenv e
    case t of
      VPi arg resf -> do
        -- check (rather than infer) the argument against the domain; for
        -- non-lambda arguments this is still inference + unification
        typeChk tenv e' arg
        v' <- evalTenv tenv e'
        resf v'
      _ -> Left $ VEInvalidLamTy t

  EPi e e' -> do
    typeChk tenv e vstar
    v  <- evalTenv tenv  e
    let tenv' = (VNeutral $ Neutral (Local $ tenvLen tenv) mempty, v):tenv
    v' <- evalTenv tenv' e'
    q' <- quote 0 v'
    typeChk tenv' q' vstar
    pure vstar

  ELam f -> Left $ VEInferLam f
  where
    -- the type of types
    vstar = VNeutral $ Neutral (NrLiteral $ LPrimTy PtType) mempty

-- | Check an expression against an expected type (given as a value).
--
--   A lambda checked against a 'VPi' type is handled structurally:
--
--   * a fresh @'Local' (tenvLen tenv)@ neutral stands in for the bound
--     variable;
--   * the body is evaluated in the extended context and quoted back;
--   * the result is checked against the codomain instantiated at that
--     variable.
--
--   Any other expression has its type inferred with 'typeInf', and the
--   result is unified with the expected type.
--
--   NOTE(review): the fresh variable is numbered by context length, while
--   'neutralQuote' numbers 'Local' relative to the quoting depth. Confirm
--   'lookupBinder' resolves both to the intended context entry.
typeChk :: [(Value, Value)] -> Expr -> Value -> Either ValueError ()
typeChk tenv (ELam body) (VPi dom cod) = do
  let fresh = VNeutral $ Neutral (Local $ tenvLen tenv) mempty
      tenv' = (fresh, dom) : tenv
  bodyVal <- evalTenv tenv' body
  codVal <- cod fresh
  bodyExpr <- quote 0 bodyVal
  typeChk tenv' bodyExpr codVal
typeChk tenv kind xty = do
  inferred <- typeInf tenv kind
  unify 0 inferred xty