{-# LANGUAGE ConstraintKinds         #-}
{-# LANGUAGE DataKinds               #-}
{-# LANGUAGE FlexibleContexts        #-}
{-# LANGUAGE FlexibleInstances       #-}
{-# LANGUAGE FunctionalDependencies  #-}
{-# LANGUAGE GADTs                   #-}
{-# LANGUAGE KindSignatures          #-}
{-# LANGUAGE MagicHash               #-}
{-# LANGUAGE MultiParamTypeClasses   #-}
{-# LANGUAGE PartialTypeSignatures   #-}
{-# LANGUAGE PatternSynonyms         #-}
{-# LANGUAGE PolyKinds               #-}
{-# LANGUAGE RankNTypes              #-}
{-# LANGUAGE ScopedTypeVariables     #-}
{-# LANGUAGE TypeApplications        #-}
{-# LANGUAGE TypeFamilies            #-}
{-# LANGUAGE TypeOperators           #-}
{-# LANGUAGE TypeSynonymInstances    #-}
{-# LANGUAGE UndecidableInstances    #-}
{-# LANGUAGE UndecidableSuperClasses #-}

--------------------------------------------------------------------------------
-- |
-- Module      :  Data.Comp.Multi.Ops
-- Copyright   :  Original (c) 2011 Patrick Bahr; modifications (c) 2020 James Koppel
-- License     :  BSD3
--
-- This module provides operators on higher-order functors. All definitions are
-- generalised versions of those in "Data.Comp.Ops".
--
--------------------------------------------------------------------------------

module Data.Comp.Multi.Ops
    ( Sum (..)
    , caseH
    , (:<:)
    , (:-<:)
    , inj
    , proj
    , (:=:)
    , spl
    , (:&:)(..)
    , RemA(..)
    , (O.:*:)(..)
    , O.ffst
    , O.fsnd
    , unsafeMapSum
    , unsafeElem
    , caseCxt
    , caseSumF
    , caseSum
    , Alts
    , Alt
    , alt
    , (<|)
    , cons
    , nil
    , Elem
    , pattern Elem
    , Mem
    , at
    , witness
    , extend
    , contract
    ) where

import Control.Monad
import Data.Type.Equality
import Data.Proxy
import Data.Functor.Identity
import Data.Comp.Multi.HFoldable
import Data.Comp.Multi.HFunctor
import Data.Comp.Multi.HTraversable
import Data.Comp.Multi.Alt
import qualified Data.Comp.Ops as O
import Data.Comp.Elem
import Data.Comp.Dict


-- | A sum of higher-order signatures, indexed by the type-level list @fs@ of
--   its summands.
--
--   The design is inspired by modular reifiable matching, as described in
--
--   * Oliveira, Bruno C. D. S., Shin-Cheng Mu, and Shu-Hung You.
--     \"Modular reifiable matching: a list-of-functors approach to two-level types.\"
--     In Haskell Symposium, 2015.
--
--   It differs in one respect: the active summand is identified by a
--   value-level integer (carried in the `Elem` witness) rather than by a
--   type-level natural. The implementation therefore uses `unsafeCoerce`
--   internally, but remains type-safe when used through the public API.
--   As a consequence, a value of this type takes constant memory regardless
--   of the number of summands (unlike vanilla datatypes à la carte). Selecting
--   its alternative also takes constant time (unlike modular reifiable
--   matching). The representation is the bare minimum: an int identifying the
--   alternative, and a pointer to the value.
data Sum (fs :: [(* -> *) -> * -> *]) (h :: * -> *) (e :: *) where
  Sum :: forall f fs h e. Elem f fs -> f h e -> Sum fs h e

-- | Project a sum onto the summand identified by the witness @e@.
--
-- Returns @'Just' x@ when 'elemEq' proves that @e@ and the sum's own witness
-- denote the same summand. In that case the stored value @x@ is known to have
-- type @f a e@. Otherwise returns 'Nothing'.
at :: Elem f fs -> Sum fs a e -> Maybe (f a e)
at e (Sum wit a) =
  case elemEq e wit of
    Just Refl -> Just a
    Nothing   -> Nothing

{-| Utility function to case on a higher-order functor sum, without exposing the
  internal representation of sums.

  'extractAt' selects the alternative in @alts@ that matches the sum's
  witness, and that alternative is applied to the stored value. -}
{-# INLINE caseH #-}
caseH :: Alts fs a e b -> Sum fs a e -> b
caseH alts (Sum wit v) = extractAt wit alts v

-- | Eliminate a sum with a function that is polymorphic in the summand and
-- may only assume that the summand satisfies @cxt@.
--
-- The @cxt@ dictionary for the stored summand is recovered from the
-- @All cxt fs@ constraint by 'dictFor' and brought into scope with '(\\)'.
{-# INLINE caseCxt #-}
caseCxt
  :: forall cxt fs a e b. (All cxt fs)
  => Proxy cxt
  -> (forall f. (cxt f) => f a e -> b)
  -> Sum fs a e
  -> b
caseCxt _ f (Sum wit v) = f v \\ dictFor @cxt wit

-- | Effectfully rewrite the summand stored in a sum, keeping its position.
--
-- The function runs under the @cxt@ dictionary of the stored summand. That
-- dictionary is obtained from @All cxt fs@ via 'dictFor'. The rewritten value
-- is re-wrapped with the original witness, inside the functor @f@.
{-# INLINE caseSumF #-}
caseSumF
  :: forall cxt f fs a e b. (All cxt fs, Functor f)
  => Proxy cxt
  -> (forall g. (cxt g) => g a e -> f (g b e))
  -> Sum fs a e
  -> f (Sum fs b e)
-- The whole fmap is the left operand of (\\): the dictionary must be in
-- scope for the call to @f@.
caseSumF _ f (Sum wit v) = (Sum wit <$> f v) \\ dictFor @cxt wit

-- | Pure counterpart of 'caseSumF': rewrite the stored summand without
-- effects, keeping its position in the sum.
--
-- Implemented by running 'caseSumF' in the 'Identity' functor.
{-# INLINE caseSum #-}
caseSum
  :: forall cxt fs a e b. (All cxt fs)
  => Proxy cxt
  -> (forall g. (cxt g) => g a e -> g b e)
  -> Sum fs a e
  -> Sum fs b e
caseSum p f = runIdentity . caseSumF p (Identity . f)

-- | Maps over a sum by delegating, via 'caseSum', to the 'HFunctor' instance
-- of whichever summand is stored. The summand's position is preserved.
instance (All HFunctor fs) => HFunctor (Sum fs) where
    hfmap f = caseSum (Proxy @HFunctor) (hfmap f)
      
-- | Folds over a sum by delegating, via 'caseCxt', to the 'HFoldable'
-- instance of whichever summand is stored.
--
-- Every method defined here is delegated the same way, including 'hfoldl1'.
-- All folds therefore use the summand's own implementation.
instance ( All HFoldable fs
         , All HFunctor fs
         ) => HFoldable (Sum fs) where
    hfold      = caseCxt (Proxy @HFoldable) hfold
    hfoldMap f = caseCxt (Proxy @HFoldable) (hfoldMap f)
    hfoldr f b = caseCxt (Proxy @HFoldable) (hfoldr f b)
    hfoldl f b = caseCxt (Proxy @HFoldable) (hfoldl f b)
    hfoldr1 f  = caseCxt (Proxy @HFoldable) (hfoldr1 f)
    hfoldl1 f  = caseCxt (Proxy @HFoldable) (hfoldl1 f)

-- | Traverses a sum by delegating, via 'caseSumF', to the 'HTraversable'
-- instance of whichever summand is stored. The result is re-wrapped at the
-- same position.
instance ( All HTraversable fs
         , All HFoldable fs
         , All HFunctor fs
         ) => HTraversable (Sum fs) where
    htraverse f = caseSumF (Proxy @HTraversable) (htraverse f)
    hmapM f     = caseSumF (Proxy @HTraversable) (hmapM f)

-- The subsumption relation.

infixl 5 :<:
infixl 5 :=:

-- | @f :<: g@ states that the signature @f@ is subsumed by @g@. An @f@-node
-- can be injected into @g@, and a @g@-node can be (partially) projected back
-- onto @f@.
class (:<:) (f :: (* -> *) -> * -> *) (g :: (* -> *) -> * -> *) where
  -- | Embed an @f@-node as a @g@-node.
  inj  :: f a :-> g a
  -- | Try to view a @g@-node as an @f@-node, returning 'Nothing' on failure.
  proj :: NatM Maybe (g a) (f a)

-- | A signature is subsumed by any sum whose list contains it ('Mem').
-- Injection tags the value with the 'witness' for @f@. Projection goes
-- through 'at' with that same witness, so it yields 'Nothing' when the sum
-- holds a different summand.
--
-- NOTE(review): this instance and the reflexive @f :<: f@ instance both match
-- @Sum fs :<: Sum fs@, and neither carries an overlap pragma. Confirm that
-- such a constraint never needs to be solved.
instance ( Mem f fs
         ) => f :<: (Sum fs) where
  inj  = Sum witness
  proj = at witness

-- | Every signature is subsumed by itself. Injection is the identity and
-- projection always succeeds.
instance f :<: f where
  inj  = id
  proj = Just

-- | Mutual subsumption: each of @f@ and @g@ is subsumed by the other.
type (:=:) f g = (f :<: g, g :<: f)

-- | Case on a node of a signature @f@ that is mutually subsumed by a sum over
-- @fs@. The node is first injected into the sum with 'inj', then dispatched
-- to the matching alternative in @alts@ with 'caseH'.
spl :: ( f :=: Sum fs
      ) => (forall e l. Alts fs a e (b l)) -> f a :-> b
spl alts = caseH alts . inj

-- Constant Products

infixr 7 :&:

-- | Adds a constant product to a signature: a node @f g e@ is paired with an
-- annotation of type @a@. The annotation's type does not depend on the
-- node's index @e@.
--
-- Alternatively, this could have been defined as
--
-- @
-- data (f :&: a) (g ::  * -> *) e = f g e :&: a e
-- @
--
-- That definition is too general, however, for example for 'productHHom'.
data (f :&: a) (g ::  * -> *) e = f g e :&: a

-- | Maps over the underlying node. The annotation is carried over unchanged.
instance (HFunctor f) => HFunctor (f :&: a) where
    hfmap f (v :&: c) = hfmap f v :&: c

-- | Folds over the underlying node only. The annotation holds no subterms and
-- is ignored.
instance (HFoldable f) => HFoldable (f :&: a) where
    hfold (v :&: _)        = hfold v
    hfoldMap f (v :&: _)   = hfoldMap f v
    hfoldr f e (v :&: _)   = hfoldr f e v
    hfoldl f e (v :&: _)   = hfoldl f e v
    hfoldr1 f (v :&: _)    = hfoldr1 f v
    hfoldl1 f (v :&: _)    = hfoldl1 f v


-- | Traverses the underlying node and re-attaches the unchanged annotation to
-- the result.
instance (HTraversable f) => HTraversable (f :&: a) where
    htraverse f (v :&: c) = (:&: c) <$> htraverse f v
    hmapM f (v :&: c)     = liftM (:&: c) (hmapM f v)

-- | @RemA s s'@ relates a signature @s@ to the signature @s'@ obtained by
-- removing its constant products ('(:&:)'). The instances in this module
-- strip those products. @s@ determines @s'@.
class RemA (s :: (* -> *) -> * -> *) (s' :: (* -> *) -> * -> *) | s -> s' where
    -- | Strip the annotation from a node.
    remA :: s a :-> s' a

-- | Base case for stripping annotations summand-wise. The empty sum has no
-- alternatives, so it maps to itself. Without this instance, the context
-- @RemA (Sum fs) (Sum gs)@ of the inductive instance below, which recurses on
-- the tail of the list, could never be discharged.
instance RemA (Sum '[]) (Sum '[]) where
  remA = id

-- | Strip annotations summand-wise.
--
-- 'contract' tests whether the witness points at the head of the list:
--
-- * If it does, the head's own 'RemA' instance strips the value, and the
--   result is re-tagged with 'witness'.
-- * Otherwise, the value is re-wrapped as a sum over the tail and stripped
--   recursively. The resulting witness is then shifted back with 'extend'.
--
-- TODO: This is linear
--       Is there a way to make this constant time?
instance ( RemA f g
         , RemA (Sum fs) (Sum gs)
         ) => RemA (Sum (f ': fs)) (Sum (g ': gs)) where
  remA (Sum w x) = case contract w of
    Left Refl -> Sum witness (remA x)
    Right w0  -> case go (Sum w0 x) of
      Sum w1 y -> Sum (extend w1) y

    where go :: (RemA (Sum fs) (Sum gs)) => Sum fs a :-> Sum gs a
          go = remA

-- | Stripping an annotated node drops the annotation and keeps the node.
instance RemA (f :&: p) f where
    remA (v :&: _) = v

-- | Apply @f@ to a summand value and wrap the result in a sum over @gs@.
--
-- Unsafe: the witness for @f@ in @fs@ is converted with 'unsafeElem'. That is
-- presumably an index-preserving coercion; confirm against its definition.
-- The caller must therefore ensure that the position @wit@ denotes in @fs@
-- holds @g@ in @gs@.
--
-- NOTE: Invariant => Length fs == Length gs
-- TODO: write gs as a function of fs.
unsafeMapSum :: Elem f fs -> f a e -> (f a :-> g a) -> Sum gs a e
unsafeMapSum wit v f = Sum (unsafeElem wit) (f v)

-- | @f :-<: fs@ is shorthand for @f :<: Sum fs@, phrased in terms of the
-- summand list. The single catch-all instance makes the two constraints
-- interchangeable.
class (f :<: Sum fs) => (:-<:) f fs
instance (f :<: Sum fs) => (:-<:) f fs