{-# LANGUAGE ConstraintKinds        #-}
{-# LANGUAGE DataKinds              #-}
{-# LANGUAGE FlexibleContexts       #-}
{-# LANGUAGE FlexibleInstances      #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs                  #-}
{-# LANGUAGE MultiParamTypeClasses  #-}
{-# LANGUAGE PolyKinds              #-}
{-# LANGUAGE RankNTypes             #-}
{-# LANGUAGE ScopedTypeVariables    #-}
{-# LANGUAGE TypeApplications       #-}
{-# LANGUAGE TypeFamilies           #-}
{-# LANGUAGE TypeOperators          #-}
{-# LANGUAGE TypeSynonymInstances   #-}
{-# LANGUAGE UndecidableInstances   #-}

--------------------------------------------------------------------------------
-- |
-- Module      :  Data.Comp.Ops
-- Copyright   :  Original (c) 2010-2011 Patrick Bahr, Tom Hvitved; modifications (c) 2020 James Koppel
-- License     :  BSD3
--
-- This module provides sum and product operators on functors.
--
-- The main item of interest is the `Sum` datatype, a constant-memory implementation of type-level sums.
--
--------------------------------------------------------------------------------

module Data.Comp.Ops
        ( module Data.Comp.Ops
        , Alts
        , Alt
        , alt
        , (<|)
        , cons
        , nil
        ) where

import Data.Foldable
import Data.Traversable
import Data.Functor.Identity

import Control.Applicative
import Control.Monad hiding (mapM, sequence)
import Data.Type.Equality
import Data.Comp.Elem
import Data.Comp.Dict
import Data.Comp.Alt

import Prelude hiding (foldl, foldl1, foldr, foldr1, mapM, sequence)

-- Right-associative: @f :+: g :+: h@ parses as @f :+: (g :+: h)@.
infixr 6 :+:

-- | Binary coproduct (sum) of two signatures: a value is either an
-- @f a@ (tagged 'Inl') or a @g a@ (tagged 'Inr').
data (f :+: g) (a :: *) = Inl (f a) -- ^ value from the left signature
                       | Inr (g a) -- ^ value from the right signature


-- | Sum of the signatures listed in @fs@. A value pairs a payload @f e@
-- with an 'Elem' witness that @f@ is a member of @fs@. Only this single
-- constructor is used, whatever the length of @fs@.
--
-- See documentation for the multi-sorted version, `Data.Comp.Multi.Ops.Sum`
data Sum (fs :: [* -> *]) e where
  Sum :: Elem f fs -> f e -> Sum fs e

-- | Project a sum onto one of its components.
--
-- @at e s@ returns @Just@ the payload of @s@ when the witness stored in @s@
-- is equal (according to 'elemEq') to @e@, and @Nothing@ otherwise.
at :: Elem f fs -> Sum fs a -> Maybe (f a)
at e (Sum wit a) =
  case elemEq e wit of
    -- Matching on 'Refl' tells the type checker the stored functor is @f@.
    Just Refl -> Just a
    Nothing   -> Nothing

{-| Utility function to case on a functor sum, without exposing the internal
  representation of sums.

  The payload is passed to the alternative that 'extractAt' picks out of
  @alts@ using the sum's 'Elem' witness. -}
{-# INLINE caseF #-}
caseF :: Alts fs a b -> Sum fs a -> b
caseF alts (Sum wit v) = extractAt wit alts v

-- | Eliminate a sum with a handler that is polymorphic in the component
-- functor, requiring only that every member of @fs@ satisfies @cxt@.
--
-- The @cxt@ dictionary for the active component is obtained from
-- @All cxt fs@ with 'dictFor' and discharged with '\\'.
{-# INLINE caseCxt #-}
caseCxt :: forall cxt fs a b. (All cxt fs) => (forall f. (cxt f) => f a -> b) -> Sum fs a -> b
caseCxt f (Sum wit v) = f v \\ dictFor @cxt wit

-- | Effectful, shape-preserving map over the active component of a sum.
--
-- The handler turns a @g a@ into an @f (g b)@. The result is re-wrapped
-- with the /same/ 'Elem' witness, so it stays at the same position in @fs@.
-- The @cxt@ dictionary is recovered exactly as in 'caseCxt'.
{-# INLINE caseSumF #-}
caseSumF :: forall cxt f fs a b. (All cxt fs, Functor f) => (forall g. (cxt g) => g a -> f (g b)) -> Sum fs a -> f (Sum fs b)
caseSumF f (Sum wit v) = Sum wit <$> f v \\ dictFor @cxt wit

-- | Pure, shape-preserving map over the active component of a sum.
-- This is 'caseSumF' run in the 'Identity' functor.
{-# INLINE caseSum #-}
caseSum :: forall cxt fs a b. (All cxt fs) => (forall g. (cxt g) => g a -> g b) -> Sum fs a -> Sum fs b
caseSum f = runIdentity . caseSumF @cxt (Identity . f)

-- | Maps over whichever component is active, using that component's own
-- 'Functor' instance.
instance (All Functor fs) => Functor (Sum fs) where
    fmap f = caseSum @Functor (fmap f)
    -- Delegated as well, so a component's own '<$' is used rather than
    -- the class default.
    x <$ s = caseSum @Functor (x <$) s

-- | Folds whichever component is active, using that component's own
-- 'Foldable' instance.
instance ( All Foldable fs
         ) => Foldable (Sum fs) where
    fold       = caseCxt @Foldable fold
    foldMap f  = caseCxt @Foldable (foldMap f)
    foldr f b  = caseCxt @Foldable (foldr f b)
    foldl f b  = caseCxt @Foldable (foldl f b)
    foldr1 f   = caseCxt @Foldable (foldr1 f)
    foldl1 f   = caseCxt @Foldable (foldl1 f)
    -- The methods below are delegated too, so a component's specialised
    -- implementation is used instead of the class defaults (which would be
    -- derived from foldr/foldMap).
    foldr' f b = caseCxt @Foldable (foldr' f b)
    foldl' f b = caseCxt @Foldable (foldl' f b)
    null       = caseCxt @Foldable null
    length     = caseCxt @Foldable length
    toList     = caseCxt @Foldable toList

-- | Traverses whichever component is active, using that component's own
-- 'Traversable' instance. The result stays at the same position in the sum.
--
-- The 'Functor' and 'Foldable' constraints are needed by the superclass
-- instances for @Sum fs@, whose contexts are @All Functor fs@ and
-- @All Foldable fs@.
instance ( All Traversable fs
         , All Functor fs
         , All Foldable fs
         ) => Traversable (Sum fs) where
    traverse f = caseSumF @Traversable (traverse f)
    sequenceA  = caseSumF @Traversable sequenceA
    mapM f     = caseSumF @Traversable (mapM f)
    sequence   = caseSumF @Traversable sequence

-- Left-associative fixity for the subsumption operator.
infixl 5 :<:
-- infixl 5 :=:

-- | @f :<: g@: the signature @f@ is subsumed by @g@, i.e. @f@ can be used
-- to construct elements in @g@.
class (f :: * -> *) :<: (g :: * -> *) where
  -- | Embed an @f@-value into @g@.
  inj  :: f a -> g a
  -- | Partial inverse of 'inj'; 'Nothing' when the value cannot be
  -- projected to @f@.
  prj  :: g a -> Maybe (f a)

-- | A member @f@ of @fs@ injects into @Sum fs@ using the 'Elem' produced by
-- 'witness' (from @Mem f fs@). Projection goes through 'at' with that same
-- witness, so it succeeds only when the sum's stored witness matches it.
instance ( Functor f
         , Mem f fs
         ) => f :<: (Sum fs) where
  inj = Sum witness
  prj = at witness

-- | Every signature is trivially subsumed by itself.
--
-- NOTE(review): for @Sum fs :<: Sum fs@ both this instance and the
-- @f :<: Sum fs@ instance match; confirm how that overlap is resolved
-- at use sites.
instance a :<: a where
  inj = id
  prj = Just

-- | A constraint @f :=: g@ holds when each signature is subsumed by the
-- other, i.e. both @f :<: g@ and @g :<: f@. (A constraint @f :<: g@
-- expresses that the signature @f@ is subsumed by @g@, i.e. @f@ can be used
-- to construct elements in @g@.)
type f :=: g = (f :<: g, g :<: f)

-- | Case on a value of any signature @f@ that is interchangeable with
-- @Sum fs@. The value is first injected into the sum with 'inj', then
-- eliminated with 'caseF'.
spl :: ( f :=: Sum fs
       ) => Alts fs a b -> f a -> b
spl alts = caseF alts . inj

-- Products

-- Right-associative: @f :*: g :*: h@ parses as @f :*: (g :*: h)@.
infixr 8 :*:

-- | Formal product of signatures (functors): a value holds both an @f a@
-- and a @g a@, built with the infix constructor ':*:'.
data (f :*: g) a = f a :*: g a


-- | The left component of a product.
ffst :: (f :*: g) a -> f a
ffst (x :*: _) = x

-- | The right component of a product.
fsnd :: (f :*: g) a -> g a
fsnd (_ :*: y) = y

-- | Maps over both components of a product independently.
instance (Functor f, Functor g) => Functor (f :*: g) where
    fmap h (x :*: y) = fmap h x :*: fmap h y


-- | Folds a product left component first, then right component.
instance (Foldable f, Foldable g) => Foldable (f :*: g) where
    -- Delegate to each component's own foldMap rather than the default,
    -- which would rebuild it from foldr.
    foldMap h (x :*: y) = foldMap h x <> foldMap h y
    -- The right fold of @y@ seeds the right fold of @x@.
    foldr h e (x :*: y) = foldr h (foldr h e y) x
    -- The left fold of @x@ seeds the left fold of @y@.
    foldl h e (x :*: y) = foldl h (foldl h e x) y
    null (x :*: y)      = null x && null y
    length (x :*: y)    = length x + length y


-- | Traverses a product by traversing both components and pairing the
-- results. The effects of the left component are combined before those
-- of the right one.
instance (Traversable f, Traversable g) => Traversable (f :*: g) where
    traverse h (x :*: y) = liftA2 (:*:) (traverse h x) (traverse h y)
    sequenceA (x :*: y)  = liftA2 (:*:) (sequenceA x) (sequenceA y)
    mapM h (x :*: y)     = liftM2 (:*:) (mapM h x) (mapM h y)
    sequence (x :*: y)   = liftM2 (:*:) (sequence x) (sequence y)