{-# LANGUAGE DataKinds              #-}
{-# LANGUAGE GADTs                  #-}
{-# LANGUAGE KindSignatures         #-}
{-# LANGUAGE TypeOperators          #-}

module Data.Comp.Multi.Alt
       ( Alts
       , Alt
       , alt
       , (<|)
       , cons
       , nil
       , extractAt
       ) where

import qualified Data.Vector as V
import Data.Vector (Vector)
import qualified Unsafe.Coerce as U
import GHC.Types

import Data.Comp.Elem

-- | One branch of a case analysis over functors: a handler that turns a
-- value of type @g h i@ into a result of type @r@.
newtype Alt g (h :: * -> *) i r where
  Alt :: (g h i -> r) -> Alt g h i r

-- | A table of handlers, one per functor in the type-level list @fs@.
--
-- Every handler is stored with its functor parameter replaced by 'Any', so
-- that handlers for different functors fit in one homogeneous 'Vector'.
-- 'cons' places the handler for the head of @fs@ at index 0.
newtype Alts (fs :: [(* -> *) -> * -> *]) (h :: * -> *) i r where
  Alts :: Vector (Alt Any h i r) -> Alts fs h i r

{-# INLINE alt #-}
-- | Wrap a handler for the functor @f@ as an 'Alt', so that it can be
-- prepended to a table of handlers with '<|' or 'cons'.
alt :: (f a e -> b) -> Alt f a e b
alt = Alt

infixr 6 <|

{-# INLINE (<|) #-}
-- | Infix synonym for 'cons'. It prepends a handler for @f@ to a table of
-- handlers and extends the type-level list of functors with @f@.
-- Right-associative, so @x <| y <| nil@ builds a two-entry table.
(<|) :: Alt f a e b -> Alts fs a e b -> Alts (f ': fs) a e b
(<|) = cons

{-# INLINE cons #-}
-- | Prepend a handler for @f@ to a table of handlers.
--
-- The new handler is stored at index 0, with its functor parameter erased
-- to 'Any' so that it fits in the homogeneous vector. Every existing
-- handler shifts up by one position, matching the extended type-level
-- list @f ': fs@.
--
-- 'V.cons' copies the underlying vector, so building a table of @n@
-- handlers with repeated 'cons' costs O(n^2) in total.
cons :: Alt f a e b -> Alts fs a e b -> Alts (f ': fs) a e b
cons handler (Alts rest) = Alts (V.cons (U.unsafeCoerce handler) rest)

-- | The empty table of handlers, indexed by the empty list of functors.
nil :: Alts '[] a e b
nil = Alts V.empty

-- | Look up the handler for @f@ in a table of handlers.
--
-- The 'Elem' value wraps an 'Int', which is used directly as an index into
-- the vector. Presumably this is the position of @f@ in @fs@; confirm
-- against "Data.Comp.Elem".
--
-- Why 'U.unsafeCoerce' is sound here: the stored entry has type
-- @Alt Any a e b@. Coercing it back to @f a e -> b@ is only sound if the
-- entry at that index was inserted by 'cons' for this same @f@. That holds
-- when the 'Elem' index follows the order in which 'cons' prepends entries
-- (head of @fs@ at index 0); it is not checked here. 'Alt' is a newtype
-- around a function, so the coercion changes only the static type.
--
-- 'V.!' is bounds-checked and calls 'error' on an out-of-range index. That
-- can only happen if the invariant above is broken.
extractAt :: Elem f fs -> Alts fs a e b -> (f a e -> b)
extractAt (Elem ix) (Alts table) = U.unsafeCoerce (table V.! ix)