{-# LANGUAGE ConstraintKinds        #-}
{-# LANGUAGE DataKinds              #-}
{-# LANGUAGE FlexibleContexts       #-}
{-# LANGUAGE FlexibleInstances      #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs                  #-}
{-# LANGUAGE MultiParamTypeClasses  #-}
{-# LANGUAGE PolyKinds              #-}
{-# LANGUAGE RankNTypes             #-}
{-# LANGUAGE ScopedTypeVariables    #-}
{-# LANGUAGE TypeApplications       #-}
{-# LANGUAGE TypeFamilies           #-}
{-# LANGUAGE TypeOperators          #-}
{-# LANGUAGE TypeSynonymInstances   #-}
{-# LANGUAGE UndecidableInstances   #-}

--------------------------------------------------------------------------------
-- |
-- Module      :  Data.Comp.Ops
-- Copyright   :  Original (c) 2010-2011 Patrick Bahr, Tom Hvitved; modifications (c) 2020 James Koppel
-- License     :  BSD3
--
-- This module provides sum and product operators on functors.
--
-- The main item of interest is the `Sum` datatype, a constant-memory implementation of type-level sums.
--
--------------------------------------------------------------------------------

module Data.Comp.Ops
        ( module Data.Comp.Ops
        , Alts
        , Alt
        , alt
        , (<|)
        , cons
        , nil
        ) where

import Data.Foldable
import Data.Traversable
import Data.Proxy
import Data.Functor.Identity

import Control.Applicative
import Control.Monad hiding (mapM, sequence)
import Data.Type.Equality
import Data.Comp.Elem
import Data.Comp.Dict
import Data.Comp.Alt

import Prelude hiding (foldl, foldl1, foldr, foldr1, mapM, sequence)

infixr 6 :+:

-- | Binary coproduct of two functors: a value is either an @f a@
-- (tagged 'Inl') or a @g a@ (tagged 'Inr').
data (f :+: g) (a :: *) where
  Inl :: f a -> (f :+: g) a
  Inr :: g a -> (f :+: g) a


-- | A type-level sum of the functors listed in @fs@. A value pairs a
-- membership witness ('Elem') that identifies one summand @f@ of @fs@ with
-- a value of type @f e@.
--
-- See the documentation of the multi-sorted version, 'Data.Comp.Multi.Ops.Sum'.
data Sum (fs :: [* -> *]) e where
  Sum :: Elem f fs -> f e -> Sum fs e

-- | View a sum as the summand picked out by the given membership witness.
-- Yields 'Nothing' unless 'elemEq' proves that witness equal to the one
-- stored inside the sum.
at :: Elem f fs -> Sum fs a -> Maybe (f a)
at target (Sum stored payload) = do
  -- Matching on 'Refl' tells the type checker the two summands coincide.
  Refl <- elemEq target stored
  pure payload

-- | Eliminate a functor sum with a table of alternatives, one per summand,
-- without exposing the internal representation of sums. The membership
-- witness stored in the sum selects the alternative (via 'extractAt').
{-# INLINE caseF #-}
caseF :: Alts fs a b -> Sum fs a -> b
caseF alts s = case s of
  Sum which payload -> extractAt which alts payload

{-# INLINE caseCxt #-}
-- | Eliminate a sum with a handler that works for any summand satisfying
-- @cxt@. The @All cxt fs@ constraint supplies, through 'dictFor', the
-- dictionary for whichever summand the sum actually holds.
caseCxt :: forall cxt fs a b. (All cxt fs)
        => Proxy cxt
        -> (forall f. (cxt f) => f a -> b)
        -> Sum fs a
        -> b
caseCxt _ handler (Sum wit v) = handler v \\ dict
  where
    dict = dictFor @cxt wit

{-# INLINE caseSumF #-}
-- | Effectfully transform the value held by a sum, keeping the same summand.
-- The handler runs on the summand's value under the @cxt@ dictionary obtained
-- from 'dictFor'. The result is re-wrapped with the original witness inside
-- the functor @f@.
caseSumF :: forall cxt f fs a b. (All cxt fs, Functor f)
         => Proxy cxt
         -> (forall g. (cxt g) => g a -> f (g b))
         -> Sum fs a
         -> f (Sum fs b)
caseSumF _ step (Sum wit v) = fmap (Sum wit) (step v) \\ dict
  where
    dict = dictFor @cxt wit

{-# INLINE caseSum #-}
-- | Pure, summand-preserving transformation of a sum. This is 'caseSumF'
-- specialised to the 'Identity' functor.
caseSum :: forall cxt fs a b. (All cxt fs)
        => Proxy cxt
        -> (forall g. (cxt g) => g a -> g b)
        -> Sum fs a
        -> Sum fs b
caseSum proxy step s =
  runIdentity (caseSumF proxy (\x -> Identity (step x)) s)

-- | Mapping over a sum maps over whichever summand it currently holds.
instance All Functor fs => Functor (Sum fs) where
    fmap g s = caseSum (Proxy :: Proxy Functor) (\x -> fmap g x) s

-- | Each method defined here forwards to the 'Foldable' instance of the
-- summand the sum currently holds. Methods not listed use the class defaults.
instance All Foldable fs => Foldable (Sum fs) where
    fold s            = caseCxt (Proxy :: Proxy Foldable) (\x -> fold x) s
    foldMap step s    = caseCxt (Proxy :: Proxy Foldable) (\x -> foldMap step x) s
    foldr step z s    = caseCxt (Proxy :: Proxy Foldable) (\x -> foldr step z x) s
    foldl step z s    = caseCxt (Proxy :: Proxy Foldable) (\x -> foldl step z x) s
    foldr1 step s     = caseCxt (Proxy :: Proxy Foldable) (\x -> foldr1 step x) s
    foldl1 step s     = caseCxt (Proxy :: Proxy Foldable) (\x -> foldl1 step x) s

-- | Each method forwards to the 'Traversable' instance of the summand the sum
-- currently holds, then re-wraps the result with the same witness (see
-- 'caseSumF').
instance ( All Traversable fs
         , All Functor fs
         , All Foldable fs
         ) => Traversable (Sum fs) where
    traverse step s = caseSumF (Proxy :: Proxy Traversable) (\x -> traverse step x) s
    sequenceA s     = caseSumF (Proxy :: Proxy Traversable) (\x -> sequenceA x) s
    mapM step s     = caseSumF (Proxy :: Proxy Traversable) (\x -> mapM step x) s
    sequence s      = caseSumF (Proxy :: Proxy Traversable) (\x -> sequence x) s

infixl 5 :<:

-- | A constraint @f :<: g@ expresses that the signature @f@ is subsumed by
-- @g@, i.e. @f@ can be used to construct elements in @g@.
class (f :: * -> *) :<: (g :: * -> *) where
  -- | Embed an @f@-value into @g@.
  inj  :: f a -> g a
  -- | Try to view a @g@-value as an @f@-value. The 'Sum' instance in this
  -- module answers 'Nothing' when 'elemEq' reports a different summand.
  prj  :: g a -> Maybe (f a)

-- | Any functor listed in @fs@ embeds into @Sum fs@. 'inj' tags the value
-- with the membership 'witness' from 'Mem', and 'prj' uses 'at' with that
-- same witness to recover it.
instance (Functor f, Mem f fs) => f :<: (Sum fs) where
  inj x = Sum witness x
  prj s = at witness s

-- | Every signature trivially embeds into itself. 'inj' is the identity and
-- 'prj' always succeeds.
--
-- NOTE(review): a constraint @Sum fs :<: Sum fs@ matches both this instance
-- and the @f :<: Sum fs@ instance. Neither head is more specific than the
-- other, so GHC may report overlapping instances there. Confirm how callers
-- such as 'spl' are resolved.
instance a :<: a where
  inj x = x
  prj x = Just x

-- | A constraint @f :=: g@ expresses that @f@ and @g@ subsume each other,
-- i.e. both @f :<: g@ and @g :<: f@ hold. Values of either signature can
-- therefore be injected into the other with 'inj'.
type f :=: g = (f :<: g, g :<: f)

-- | Case on a signature that is interchangeable with @Sum fs@. The value is
-- first injected into the sum, then dispatched with 'caseF'.
spl :: (f :=: Sum fs) => Alts fs a b -> f a -> b
spl alts x = caseF alts (inj x)

-- Products

infixr 8 :*:

-- | Formal product of signatures (functors): an @f a@ paired with a @g a@.
data (f :*: g) a where
  (:*:) :: f a -> g a -> (f :*: g) a


-- | Left component of a functor product.
ffst :: (f :*: g) a -> f a
ffst p = case p of
  l :*: _ -> l

-- | Right component of a functor product.
fsnd :: (f :*: g) a -> g a
fsnd p = case p of
  _ :*: r -> r

-- | Mapping over a product maps over both components independently.
instance (Functor f, Functor g) => Functor (f :*: g) where
    fmap h p = case p of
      l :*: r -> (:*:) (fmap h l) (fmap h r)


-- | Folding a product visits the left component's elements before the right
-- component's elements.
instance (Foldable f, Foldable g) => Foldable (f :*: g) where
    -- Right fold: fold the right component first; its result seeds the fold
    -- over the left component.
    foldr step z (l :*: r) = foldr step seed l
      where
        seed = foldr step z r
    -- Left fold: fold the left component first, then continue into the right.
    foldl step z (l :*: r) = foldl step acc r
      where
        acc = foldl step z l


-- | Traversing a product runs the effects for the left component, then those
-- for the right component, and pairs the results back up with ':*:'.
instance (Traversable f, Traversable g) => Traversable (f :*: g) where
    traverse step (l :*: r) = liftA2 (:*:) left right
      where
        left  = traverse step l
        right = traverse step r
    sequenceA (l :*: r) = liftA2 (:*:) (sequenceA l) (sequenceA r)
    mapM step (l :*: r) = do
      l' <- mapM step l
      r' <- mapM step r
      return (l' :*: r')
    sequence (l :*: r) = do
      l' <- sequence l
      r' <- sequence r
      return (l' :*: r')