-- Half-completed crypto experiments in Haskell.
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PackageImports #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskellQuotes #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-- Needed for nested type family application in `Exp`.
{-# LANGUAGE UndecidableInstances #-}
-- __FIXME__: These are required because of two things. The first is
--           "Data.Bits". We should replace this with a statically-sized
--            alternative @StaticBits (n :: `Nat`) a@ with indexed lookups,
--            etc. It might still be unsafe for primitive types, though. The
--            second is "Data.Vector.Generic". This is currently an efficient
--            representation of arbitrary-length bit strings. Hopefully we can
--            find an alternative for this, too.
{-# OPTIONS_GHC -Wno-missing-safe-haskell-mode -Wno-unsafe #-}
-- Because of @`Bits` `Vec`@ instances.
{-# OPTIONS_GHC -Wno-orphans #-}

-- | This is an implementation of [Keccak](https://keccak.team/keccak.html) with
--   a lot of static guarantees.
--
--   There are some deviations from the spec:
-- * `keccak_f`, rather than taking the value of @b@ takes the value of @l@, so
--   where the spec says "Keccak-/f/[1600]", we would write
--   @`keccak_f` (Proxy \@(`Nat.fromGHC` 6))@.
-- * `keccak_rc` and `keccak_c` each take an extra `Proxy` (@l@ and @r@,
--   respectively), but these are uniquely determined by the other parameters,
--   so it's a minor annoyance.
module Sponge
  ( sponge,
    Block,
    W,
    WBits,
    State,
    B,
  )
where

import Control.Applicative (Applicative (..))
import Control.Arrow (Arrow (..))
import Control.Category (Category (..))
import Data.Bit (Bit (..), Vector)
import Data.Bitraversable (Bitraversable (..))
import Data.Bits (Bits (..), FiniteBits (..))
import Data.Bool (Bool (..))
import Data.Foldable (Foldable (..))
import Data.Functor (Functor (..), (<$>))
import Data.Maybe (Maybe)
import Data.Proxy (Proxy (..))
import Data.Semigroup (Semigroup (..))
import Data.Tuple (uncurry)
import Data.Type.Nat (Nat (..))
import Data.Type.Nat qualified as Nat
import Data.Type.Nat.LE qualified as Nat
import Data.Vec.Lazy (Vec (..))
import Data.Vec.Lazy qualified as Vec
import Data.Vector.Generic qualified as V
import Yaya.Fold (Corecursive (..), Projectable (..), Steppable (..))
import Yaya.Zoo (Stream)
import Prelude (Num (..), fromIntegral, undefined, ($))

-- | This allows us to do bit operations at whatever level of the arrays that we
--   want. Every operation defined here is applied element-wise.
--
--  __NB__: This instance is still incomplete. `shift`, `rotate`, `bit`,
--          `testBit`, `bitSize`, and `isSigned` are not implemented for `Vec`,
--          so they must not be called (nor any default method that relies on
--          them).
instance (Nat.SNatI n, Bits a) => Bits (Vec n a) where
  (.&.) = Vec.zipWith (.&.)
  (.|.) = Vec.zipWith (.|.)
  xor = Vec.zipWith xor
  complement = fmap complement

  -- The class default is @`clearBit` (`bit` 0) 0@, which would go through the
  -- unimplemented `bit`, so the all-zero value has to be spelled out here.
  zeroBits = pure zeroBits

  -- The total width is the width of one element times the number of elements.
  -- The `undefined` is only there to select the element type.
  bitSizeMaybe _ = (Nat.reflectToNum (Proxy @n) *) <$> bitSizeMaybe @a undefined
  {-# INLINEABLE bitSizeMaybe #-}

  -- The number of set bits is the sum over all of the elements.
  popCount = sum . fmap popCount

-- | A `Vec` of fixed-width elements has a fixed width itself: the number of
--   elements times the width of a single element.
instance (Nat.SNatI n, FiniteBits a) => FiniteBits (Vec n a) where
  finiteBitSize _ = elementCount * elementWidth
    where
      -- The length of the `Vec`, reflected from its type-level index.
      elementCount = Nat.reflectToNum (Proxy @n)
      -- `undefined` is only there to select the element type.
      elementWidth = finiteBitSize @a undefined

-- | Type-level exponentiation: @`Exp` base power@ reduces to @base ^ power@.
--
--   The equations are tried from top to bottom. @`Exp` 'Z 'Z@ matches none of
--   them, so it never reduces.
type family Exp (base :: Nat) (power :: Nat) :: Nat where
  -- Zero raised to a positive power is zero.
  Exp 'Z ('S _) = 'Z
  -- A positive base raised to the power zero is one.
  Exp ('S _) 'Z = 'S 'Z
  -- Shortcut for the first power; avoids an extra @Nat.Mult base 1@ at the end.
  Exp base ('S 'Z) = base
  -- Otherwise peel one factor off of the exponent.
  Exp base ('S power) = Nat.Mult base (Exp base power)

-- = Keccak-specific definitions

-- | The number of bits in a row or column, as a type-level `Nat` (5).
type Block = Nat.FromGHC 5

-- | The number of bits in a lane, i.e. @2 ^ l@ (computed with `Exp`).
type W l = Exp (Nat.FromGHC 2) l

-- | A structure containing @`W` l@ bits, currently one `Bool` per bit.
--
--  __TODO__: Replace this with the more efficient type family below (once we
--            switch to using bitwise operations on @`Vec` n `Bool`@).
type WBits (l :: Nat) = Vec (W l) Bool

-- type family WBits (l :: Nat) where
--   WBits 'Z = Bool
--   -- | This could use better types for the 2- and 4-bit cases.
--   WBits ('S 'Z) = Vec (Nat.FromGHC 2) Bool
--   WBits ('S ('S 'Z)) = Vec (Nat.FromGHC 4) Bool
--   WBits ('S ('S ('S 'Z))) = Word8
--   WBits ('S ('S ('S ('S 'Z)))) = Word16
--   WBits ('S ('S ('S ('S ('S 'Z))))) = Word32
--   WBits ('S ('S ('S ('S ('S ('S 'Z)))))) = Word64

-- | The state that the permutation operates on: a `Block` by `Block` grid whose
--   cells each hold @`WBits` l@, i.e. @`B` l@ bits in total.
type State l = Vec Block (Vec Block (WBits l))

-- | The width of the permutation in bits: @`Block` * `Block` * `W` l@, which is
--   the number of `Bool`s held by a @`State` l@.
type B l = Nat.Mult Block (Nat.Mult Block (W l))

-- | Splits a flat `Vec` of bits into @x@ groups of @y@ `WBits` values each.
--
--   The flat `Vec` is first chunked into @x@ pieces, and then every piece is
--   chunked again into @y@ pieces of @`W` l@ bits. The `Proxy` is only there to
--   pin down @l@.
restructureVec ::
  (Nat.SNatI x, Nat.SNatI y, Nat.SNatI (W l), Nat.SNatI (Nat.Mult y (W l))) =>
  Proxy l ->
  Vec (Nat.Mult x (Nat.Mult y (W l))) Bool ->
  Vec x (Vec y (WBits l))
restructureVec Proxy flat = Vec.chunks <$> Vec.chunks flat

-- | Flattens a nested `Vec` of `WBits` values into a single flat `Vec` of bits.
--
--   Every inner @`Vec` y (`WBits` l)@ is concatenated first, and then the
--   results are concatenated with each other. The `Proxy` is only there to pin
--   down @l@.
unstructureVec ::
  Proxy l ->
  Vec x (Vec y (WBits l)) ->
  Vec (Nat.Mult x (Nat.Mult y (W l))) Bool
unstructureVec Proxy nested = Vec.concat (Vec.concat <$> nested)

-- | A length-indexed `Vec` paired with an extra value of type @w@.
--
--   The length and element type are the last two parameters so that
--   @`ComposedVec` w@ can be the accumulator of `Nat.induction1`, which is how
--   `takeFromStream` uses it (with the remaining stream as @w@).
data ComposedVec w n a = ComposedVec
  { -- | The `Vec` built up so far.
    getVec :: Vec n a,
    -- | The value carried along next to the `Vec`.
    getWritten :: w
  }

-- | Collects exactly @n@ elements from an infinite stream, in the order in
--   which the stream yields them.
takeFromStream :: Nat.SNatI n => Stream a -> Vec n a
takeFromStream stream =
  -- Every induction step conses the newest element onto the front, so the
  -- accumulated `Vec` is newest-first and has to be reversed to get back to
  -- stream order.
  Vec.reverse . getVec $
    Nat.induction1
      (ComposedVec VNil stream)
      (\(ComposedVec v s) -> uncurry ComposedVec . first (::: v) $ project s)

-- | `takeFromStream`, restricted to streams of non-empty bit blocks and to a
--   non-zero number of blocks.
--
--  __FIXME__: Get rid of this function. Currently it breaks if we just call
--            `takeFromStream`.
takeEnough ::
  Nat.SNatI n => Stream (Vec ('S r) Bool) -> Vec ('S n) (Vec ('S r) Bool)
takeEnough blocks = takeFromStream blocks

-- | The sponge construction. The input is extended with the padding produced
--   by the supplied padding function (which receives the rate and the length
--   of the input), cut into blocks of @r@ bits, and absorbed block by block –
--   in message order – into an all-zero `State`, applying the permutation
--   after every block. Blocks of @r@ bits are then squeezed out of the state
--   (applying the permutation between blocks) and truncated to @l'@ bits.
--
--  __NB__: In [/Cryptographic sponge
--          functions/](https://keccak.team/files/CSF-0.1.pdf#page=14), the
--          sponge construction relies on @b@. However, it never includes it as
--          a parameter. Here, because @b = r + c@ and because subtraction is
--          difficult, we replace @b - r@ with @c@.
--
--  __FIXME__: The constraints around @n@ here trigger a cascade of annoyances,
--             where the caller eventually needs to explicitly provide the value
--             satisfying @⌈l' / r⌉@.
--
--  __FIXME__: This returns in `Maybe` because @cut@ can’t tell that the length
--             of the input `Vector` is exactly some multiple of @r@. We should
--             change how padding works to provide this guarantee at the type
--             level. Also, a result of `Nothing` would indicate a bug in this
--             implementation. Which means the failure should at least be more
--             informative, but possibly even an exception. However, returning
--             `Maybe` is a good reminder to fix this, as it makes the
--             shortcoming apparent to all callers.
sponge ::
  forall l' l r c n.
  ( Nat.SNatI n,
    Nat.SNatI (W l),
    Bits (WBits l),
    Nat.SNatI r,
    Nat.SNatI c,
    Nat.SNatI (Nat.Mult Block (W l)),
    B l ~ Nat.Plus ('S r) ('S c),
    Nat.LE ('S r) (B l),
    Nat.LE ('S l') (Nat.Mult ('S n) ('S r))
  ) =>
  Proxy l ->
  (State l -> State l) ->
  (Proxy ('S r) -> Nat -> Vector Bit) ->
  Proxy ('S r) ->
  Proxy ('S c) ->
  -- | The optimal value here is  ⌈l' / r⌉, but any value /at least/ that high
  --   will suffice.
  Proxy ('S n) ->
  Vector Bit ->
  Maybe (Vec ('S l') Bool)
sponge l f pad r Proxy Proxy =
  fmap (Vec.take . Vec.concat . takeEnough @n . squeeze . absorb)
    . cut
    . uncurry (<>)
    . (id &&& pad r . fromIntegral . V.length)
  where
    -- Cuts the padded message into blocks of @r@ bits. The result is `Nothing`
    -- if any chunk (in practice, the last one) is not exactly @r@ bits long.
    cut :: Vector Bit -> Maybe [Vec ('S r) Bool]
    cut p =
      if V.null p
        then pure []
        else
          fmap (uncurry (:))
            . bitraverse
              (Vec.fromList . V.toList . V.map unBit)
              cut
            $ V.splitAt (Nat.reflectToNum r) p
    -- Absorbs the blocks from first to last, starting from the all-zero state.
    -- Each block is extended with @c@ zero bits to the full width, XORed into
    -- the state, and followed by one application of the permutation. This has
    -- to be a /left/ fold: a right fold would absorb the last block first.
    absorb :: [Vec ('S r) Bool] -> State l
    absorb =
      foldl'
        ( \s p_i ->
            f (s `xor` restructureVec l (p_i Vec.++ pure @(Vec ('S c)) False))
        )
        . pure
        $ pure zeroBits
    -- Produces the endless stream of output blocks: the first @r@ bits of the
    -- state as it is after absorbing, then the first @r@ bits after each
    -- further application of the permutation.
    squeeze :: State l -> Stream (Vec ('S r) Bool)
    squeeze =
      embed
        . ( Vec.take @('S r) . unstructureVec l
              &&& ana ((Vec.take @('S r) . unstructureVec l &&& id) . f)
          )