{-# LANGUAGE ConstraintKinds        #-}
{-# LANGUAGE DataKinds              #-}
{-# LANGUAGE FlexibleContexts       #-}
{-# LANGUAGE FlexibleInstances      #-}
{-# LANGUAGE GADTs                  #-}
{-# LANGUAGE KindSignatures         #-}
{-# LANGUAGE MagicHash              #-}
{-# LANGUAGE MultiParamTypeClasses  #-}
{-# LANGUAGE PatternSynonyms        #-}
{-# LANGUAGE PolyKinds              #-}
{-# LANGUAGE RoleAnnotations        #-}
{-# LANGUAGE ScopedTypeVariables    #-}
{-# LANGUAGE TypeFamilies           #-}
{-# LANGUAGE TypeOperators          #-}
{-# LANGUAGE UndecidableInstances   #-}


--------------------------------------------------------------------------------
-- |
-- Module      :  Data.Comp.Elem
-- Copyright   :  (c) 2020 James Koppel
-- License     :  BSD3
--
-- Defines the `Elem` type, so that @Elem f fs@ is type-level evidence that @f@ is a member of
-- the list of types @fs@.
--
--------------------------------------------------------------------------------

module Data.Comp.Elem
       ( Elem
       , pattern Elem
       , Mem
       , witness
       , elemEq
       , comparePos
       , extend
       , contract
       , unsafeElem
       ) where

import Data.Proxy
import GHC.TypeLits
import Data.Type.Equality
import qualified Unsafe.Coerce as U

-- | @Elem f fs@ is type-level evidence that @f@ is a member of the list of types @fs@.
--   At runtime it is just an 'Int': the index of @f@ in @fs@.
--   The `Elem` pattern exposes that index read-only; the safe constructor is `witness`.
data Elem (f :: k) (fs :: [k]) where
  Elem# :: Int -> Elem f fs

-- Both parameters must be nominal. Elem# stores only an Int, so GHC would
-- otherwise infer phantom roles for @f@ and @fs@. That would let
-- 'Data.Coerce.coerce' turn an @Elem f fs@ into an @Elem g fs@ for any @g@,
-- even in modules where 'Elem#' is not in scope. `elemEq` and `contract` would
-- then hand out bogus @f :~: g@ proofs.
type role Elem nominal nominal

-- | Read-only view of the index stored in an `Elem`.
--   This pattern can only be used to match (destruct) an `Elem`.
--   To construct one, use `witness`.
pattern Elem :: Int -> Elem f fs
pattern Elem n <- Elem# n

-- A unidirectional pattern synonym is not known to be exhaustive by default.
-- Without this pragma, every function matching on 'Elem' (including the ones
-- in this module) gets an incomplete-patterns warning, even though 'Elem'
-- covers every value of type @Elem f fs@.
{-# COMPLETE Elem #-}


-- | @Position f fs@ is the zero-based index, as a type-level natural, of the
--   first occurrence of @f@ in @fs@.
--
--   If the search reaches the empty list, the family reduces to a custom
--   'TypeError' naming the missing type. Previously it got stuck there, and a
--   failing `Mem` constraint surfaced as an opaque
--   @KnownNat (1 + Position f '[])@ error.
type family Position (f :: k) (fs :: [k]) :: Nat where
  Position f (f ': fs) = 0
  Position f (g ': fs) = 1 + Position f fs
  Position f '[]       =
    TypeError ('Text "Data.Comp.Elem: the type "
               ':<>: 'ShowType f
               ':<>: 'Text " is not a member of the type-level list")

-- | @Mem f fs@ is satisfied exactly when the typechecker can statically compute
--   @Position f fs@, i.e. when it can deduce that @f@ is contained in @fs@.
--
--   It is effectively a constraint synonym for @KnownNat (Position f fs)@:
--   the single catch-all instance below discharges it from that superclass.
class    KnownNat (Position f fs) => Mem (f :: k) (fs :: [k])
instance KnownNat (Position f fs) => Mem  f          fs

-- | Safe constructor for `Elem`.
--
--   Whenever @Mem f fs@ holds, this builds the evidence that @f@ is in @fs@.
--   The stored index is the compile-time value of @Position f fs@, which is
--   read back via 'natVal'.
witness :: forall f fs. (Mem f fs) => Elem f fs
witness = Elem# . fromInteger $ natVal index
  where
    -- Proxy fixing which type-level natural 'natVal' should reflect.
    index :: Proxy (Position f fs)
    index = Proxy
{-# INLINE witness #-}

-- | Decide whether two membership witnesses for the same list @fs@ point at
--   the same slot. If they do, return a proof that @f@ and @g@ are equal.
--
--   This relies on the invariant, established by `witness`, that an `Elem`'s
--   index is the position of its type in @fs@. Two equal indices into the same
--   list therefore name the same type. GHC cannot see that, hence the coerced
--   'Refl'. Values built with `unsafeElem` can break the invariant.
elemEq :: forall f g fs. Elem f fs -> Elem g fs -> Maybe (f :~: g)
elemEq (Elem# i) (Elem# j)
  | i /= j    = Nothing
  | otherwise = Just (U.unsafeCoerce Refl)
{-# INLINE elemEq #-}

-- | Order two members of the same list @fs@ by their index in that list.
comparePos :: Elem f fs -> Elem g fs -> Ordering
comparePos (Elem# i) (Elem# j) = i `compare` j
{-# INLINE comparePos #-}

-- | Weaken a membership witness. If @f@ is in @fs@, it is also in
--   @g ': fs@, one slot further along, so the stored index grows by one.
extend :: Elem f fs -> Elem f (g ': fs)
extend (Elem# i) = Elem# (1 + i)
{-# INLINE extend #-}

-- | Inverse of `extend`: split a witness for @g ': fs@ into one of two cases.
--
--   * @Left@: the index is 0 or below, so @f@ is the head @g@. This yields a
--     proof of @f :~: g@.
--   * @Right@: the index is positive, so @f@ occurs in the tail @fs@, one slot
--     earlier.
--
--   The equality proof is coerced, which relies on the same index invariant
--   that `witness` establishes.
contract :: Elem f (g ': fs) -> Either (f :~: g) (Elem f fs)
contract (Elem# i) =
  if i <= 0
    then Left (U.unsafeCoerce Refl)
    else Right (Elem# (i - 1))
{-# INLINE contract #-}

-- | Completely unsafe. USE WITH CARE.
--
--   Retypes a witness at arbitrary @g@ and @gs@ and keeps its index unchanged.
--   Nothing checks that @g@ actually sits at that index in @gs@. A wrong result
--   breaks the invariant that `elemEq` and `contract` rely on for their
--   coerced equality proofs.
unsafeElem :: Elem f fs -> Elem g gs
unsafeElem (Elem# n) = Elem# n
{-# INLINE unsafeElem #-}