{-# LANGUAGE EmptyDataDecls      #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE KindSignatures      #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators       #-}
--------------------------------------------------------------------------------
-- |
-- Module      :  Data.Comp.Multi.Term
-- Copyright   :  (c) 2011 Patrick Bahr
-- License     :  BSD3
-- Maintainer  :  Patrick Bahr <paba@diku.dk>
-- Stability   :  experimental
-- Portability :  non-portable (GHC Extensions)
--
-- This module defines the central notion of mutually recursive (or
-- higher-order) /terms/ and its generalisation to (higher-order) contexts.
-- All definitions are generalised versions of those in "Data.Comp.Term".
--
--------------------------------------------------------------------------------

module Data.Comp.Multi.Term
    (Cxt (..),
     Hole,
     NoHole,
     Context,
     ContextS,
     HFix,
     Term,
     CxtS,
     Const,
     constTerm,
     unTerm,
     toCxt,
     simpCxt
     ) where

import Data.Comp.Multi.HFoldable
import Data.Comp.Multi.HFunctor
import Data.Comp.Multi.HTraversable
import Data.Comp.Multi.Ops

import Control.Monad

import Unsafe.Coerce

-- | One layer of the signature @f@ whose argument positions are filled with
-- the trivial family @'K' ()@. A value of this type that does not use its
-- argument positions is a /constant/ and can be turned into a term with
-- 'constTerm'.
type Const (f :: (* -> *) -> * -> *) = f (K ())

-- | This function converts a constant to a term. This assumes that
-- the argument is indeed a constant, i.e. does not have a value for
-- the argument type of the functor f.
--
-- Every argument position is replaced by an error thunk. If the assumption
-- is violated, evaluation only fails when such a position is actually
-- forced, and it then fails with a descriptive message.
constTerm :: (HFunctor f) => Const f :-> HFix f
constTerm = Term . hfmap (\_ -> error notConstantMsg)
  where
    notConstantMsg = "Data.Comp.Multi.Term.constTerm: the argument is not a constant"

-- | Contexts over a signature, i.e. terms containing zero or more holes.
-- The type parameters are, in order:
--
--   * a phantom index, which is supposed to be one of 'Hole' and 'NoHole';
--
--   * the signature the context is built from;
--
--   * the type family of the holes;
--
--   * the index (label) of the context itself.
data Cxt h f a i where
    -- | One layer of the signature whose subterms are again contexts.
    Term :: sig (Cxt hole sig holes) ix -> Cxt hole sig holes ix
    -- | A hole; this constructor is only available at phantom index 'Hole'.
    Hole :: holes ix -> Cxt Hole sig holes ix

-- | Phantom index: the 'Cxt' is guaranteed not to contain holes.
data NoHole

-- | Phantom index: the 'Cxt' may contain holes.
data Hole

-- | A (higher-order) term, i.e. a context without holes. Its hole family is
-- fixed to @'K' ()@.
type HFix f = Cxt NoHole f (K ())

-- | A term over the signature @'Sum' fs@.
type Term fs = HFix (Sum fs)

-- | A context, which may contain holes.
type Context = Cxt Hole

-- | A context with phantom index @h@ over the signature @'Sum' fs@.
type CxtS h fs a = Cxt h (Sum fs) a

-- | A context, which may contain holes, over the signature @'Sum' fs@.
type ContextS fs a = CxtS Hole fs a

-- | This function unravels the given term at the topmost layer.
--
-- The single equation is exhaustive. A term has phantom index 'NoHole', and
-- the 'Hole' constructor only exists at index 'Hole', so the top layer of a
-- term is always a 'Term' node.
unTerm :: HFix f t -> f (HFix f) t
unTerm (Term t) = t

-- | Mapping over a context applies the function to every hole. Each 'Term'
-- layer is rebuilt with the signature's own 'hfmap', which recurses into
-- the subcontexts.
instance (HFunctor f) => HFunctor (Cxt h f) where
    hfmap g (Hole x) = Hole (g x)
    hfmap g (Term t) = Term (hfmap (hfmap g) t)

-- | Folds over a context visit its holes. Within each 'Term' layer, the
-- order is the one given by the signature's own 'HFoldable' instance.
instance (HFoldable f) => HFoldable (Cxt h f) where
    -- Each method delegates to a helper with an explicit @forall@. Under
    -- ScopedTypeVariables this puts the element types in scope, so the
    -- local worker @run@ can be given a signature. @run@ applies the user's
    -- function at 'Hole's and recurses through the signature's own fold at
    -- 'Term' nodes.
    hfoldr = hfoldr'
      where
        hfoldr' :: forall a b. (a :=> (b -> b)) -> b -> Cxt h f a :=> b
        hfoldr' op c cxt = run cxt c
          where
            run :: Cxt h f a :=> (b -> b)
            run (Hole x) e = x `op` e
            run (Term t) e = hfoldr run e t

    hfoldl = hfoldl'
      where
        hfoldl' :: forall a b. (b -> a :=> b) -> b -> Cxt h f a :=> b
        hfoldl' op = run
          where
            run :: b -> Cxt h f a :=> b
            run e (Hole x) = e `op` x
            run e (Term t) = hfoldl run e t

    hfold (Hole (K v)) = v
    hfold (Term t) = hfoldMap hfold t

    hfoldMap = hfoldMap'
      where
        hfoldMap' :: forall m a. Monoid m => (a :=> m) -> Cxt h f a :=> m
        hfoldMap' g = run
          where
            run :: Cxt h f a :=> m
            run (Hole x) = g x
            run (Term t) = hfoldMap run t

-- | Traversals of a context run the effectful function on every hole. Each
-- 'Term' layer is rebuilt from the results of the signature's own
-- traversal.
instance (HTraversable f) => HTraversable (Cxt h f) where
    -- The helper's explicit @forall@ brings @m@, @a@ and @b@ into scope
    -- (ScopedTypeVariables) for the signature of the local worker @run@.
    hmapM = hmapM'
      where
        hmapM' :: forall m a b. (Monad m)
               => NatM m a b -> NatM m (Cxt h f a) (Cxt h f b)
        hmapM' g = run
          where
            run :: NatM m (Cxt h f a) (Cxt h f b)
            run (Hole x) = liftM Hole $ g x
            run (Term t) = liftM Term $ hmapM run t

    htraverse g (Hole x) = Hole <$> g x
    htraverse g (Term t) = Term <$> htraverse (htraverse g) t

-- | Turns a single layer of the signature into a context. Every argument
-- position of the layer is wrapped in a 'Hole'.
simpCxt :: (HFunctor f) => f a i -> Context f a i
simpCxt = Term . hfmap Hole

{-| Cast a term over a signature to a context over the same signature. -}
toCxt :: (HFunctor f) => HFix f :-> Context f a
{-# INLINE toCxt #-}
-- Why 'unsafeCoerce' is sound here:
--
--   * A term has phantom index 'NoHole', and the 'Hole' constructor only
--     exists at index 'Hole', so a term's value never contains a 'Hole'.
--
--   * Therefore neither the phantom index nor the hole family @K ()@ appears
--     in the value, and reinterpreting it at another hole family is safe.
--
-- Semantically this is equivalent to @Term . hfmap toCxt . unTerm@, but it
-- avoids traversing and rebuilding the whole term.
toCxt = unsafeCoerce