{-# LANGUAGE ConstraintKinds       #-}
{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE KindSignatures        #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE Rank2Types            #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TypeApplications      #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE UndecidableInstances  #-}
--------------------------------------------------------------------------------
-- |
-- Module      :  Data.Comp.Multi.Annotation
-- Copyright   :  (c) 2011 Patrick Bahr
-- License     :  BSD3
-- Maintainer  :  Patrick Bahr <paba@diku.dk>
-- Stability   :  experimental
-- Portability :  non-portable (GHC Extensions)
--
-- This module defines annotations on signatures. All definitions are
-- generalised versions of those in "Data.Comp.Annotation".
--
--------------------------------------------------------------------------------

module Data.Comp.Multi.Annotation
    (
     AnnTerm,
     AnnHFix,
     (:&:) (..),
     RemA (..),
     liftA,
     ann,
     liftA',
     stripA,
     propAnn,
     project',
     isNode',
     inj',
     inject',
     injectOpt,
     caseH',
     caseCxt',
     caseCxt'',
     DistAnn,
     AnnCxt,
     AnnCxtS
    ) where

import Data.Proxy ( Proxy )
import Data.Comp.Dict
import Data.Comp.Elem
import Data.Comp.Multi.Algebra
import Data.Comp.Multi.HFunctor
import Data.Comp.Multi.Ops
import Data.Comp.Multi.Sum
import Data.Comp.Multi.Term

-- | 'HFix' over the signature @f@ in which every node carries an
-- annotation of type @a@ (via ':&:').
type AnnHFix a f = HFix (f :&: a)
-- | Like 'AnnHFix', but over the sum of the signatures in the list @fs@.
type AnnTerm a fs = HFix (Sum fs :&: a)

-- | A 'Cxt' over the signature @f@ whose nodes are annotated with @p@;
-- @h@ and @a@ are passed straight through to 'Cxt'.
type AnnCxt p h f a = Cxt h (f :&: p) a
-- type AnnContext p f a = AnnCxt p Hole f a

-- | 'AnnCxt' over the sum of the signatures in the list @fs@.
type AnnCxtS p h fs a = AnnCxt p h (Sum fs) a
-- type AnnContextS p fs a = AnnContext p (Sum fs) a

-- | Lift a function defined on the unannotated signature @s'@ to the
-- annotated signature @s@. The annotation of the argument is discarded
-- with 'remA' before the function is applied.
liftA :: (RemA s s') => (s' a :-> t) -> s a :-> t
liftA f v = f (remA v)

-- | This function annotates each subterm of the given context with the
-- given value (of type @p@). The annotation is attached by lifting the
-- signature function @(:&: c)@ with 'appSigFun'.
ann :: (HFunctor f) => p -> CxtFun f (f :&: p)
ann c = appSigFun (:&: c)

-- | Lift a function that maps a single unannotated node to a context
-- into one that works on annotated nodes. The annotation @p@ is taken
-- off the input node, the function is applied to the bare node, and
-- every node of the resulting context is annotated with that same @p@
-- (using 'ann').
liftA' :: (HFunctor s)
       => (s a :-> Cxt h s a) -> (s :&: p) a :-> Cxt h (s :&: p) a
liftA' f (v' :&: p) = ann p (f v')

{-| This function strips the annotations from a context (or term) over
a signature with annotations. 'remA' is applied at every node by
lifting it with 'appSigFun'. -}
stripA :: (RemA g f, HFunctor g) => CxtFun g f
stripA = appSigFun remA


-- | Propagate annotations through a term homomorphism. The homomorphism
-- is applied to the bare node, and the node's annotation @p@ is copied
-- onto every node of the context it produces (using 'ann').
propAnn :: (HFunctor g) => Hom f g -> Hom (f :&: p) (g :&: p)
propAnn alg (x :&: p) = ann p (alg x)

-- | This function is similar to 'project' but applies to signatures
-- with an annotation, which is then ignored. For a 'Term' node, the
-- annotation is removed with 'remA' and 'proj' then tries to view the
-- node as an @s@-node. Any other 'Cxt' value yields 'Nothing'.
project' :: (RemA f f', s :<: f') => Cxt h f a i -> Maybe (s (Cxt h f a) i)
project' (Term x) = proj (remA x)
project' _        = Nothing

-- | Like 'isNode', but for contexts over an annotated signature. All
-- annotations are removed with 'stripA' before 'isNode' is applied.
isNode' :: (HFunctor g, RemA g g', f :<: g') => Proxy f -> Cxt h g a l -> Bool
isNode' p t = isNode p (stripA t)

-- | Inject an annotated node of signature @f@ into the larger signature
-- @g@ with 'inj', keeping its annotation unchanged.
inj' :: (f :<: g) => (f :&: p) e l -> (g :&: p) e l
inj' (x :&: p) = inj x :&: p

-- | Like 'inj'', but additionally wraps the injected annotated node in a
-- 'Term' constructor.
inject' :: (f :<: g) => (f :&: p) (Cxt h (g :&: p) a) :-> Cxt h (g :&: p) a
inject' = Term . inj'

-- | Inject an unannotated node, whose children are already terms with
-- optional annotations, into the term type. Only the new root node is
-- annotated, and its annotation is 'Nothing'; the children are left as
-- they are.
injectOpt :: (f :<: g) => f (AnnHFix (Maybe p) g) l -> AnnHFix (Maybe p) g l
injectOpt t = inject' (t :&: Nothing)

-- | Case analysis on an annotated sum. The annotation is first pushed
-- into the selected summand with 'distAnn'. The resulting sum over the
-- annotated signatures @DistAnn fs a@ is then dispatched to @alts@ with
-- 'caseH'.
caseH' :: forall fs a e l t. Alts (DistAnn fs a) e l t -> (Sum fs :&: a) e l -> t
caseH' alts = caseH alts . distAnn

-- | Apply a function that is polymorphic in the summand to an annotated
-- sum, provided every signature in @fs@ satisfies @cxt@. The @cxt f@
-- dictionary for the summand actually present is obtained from its
-- membership witness with 'dictFor' and supplied to @f@ via '\\'.
caseCxt'
  :: forall cxt fs a e l t. (All cxt fs)
  => Proxy cxt
  -> (forall f. (cxt f) => (f :&: a) e l -> t)
  -> (Sum fs :&: a) e l
  -> t
caseCxt' _ f (Sum wit v :&: a) = f (v :&: a) \\ dictFor @cxt wit

-- | Like 'caseCxt'', but the constraint @cxt@ is required of the annotated
-- summands @f :&: a@ (i.e. of every element of @DistAnn fs a@) rather
-- than of the bare summands @f@.
caseCxt''
  :: forall cxt fs a e l t. (All cxt (DistAnn fs a))
  => Proxy cxt
  -> (forall f. (cxt (f :&: a)) => (f :&: a) e l -> t)
  -> (Sum fs :&: a) e l
  -> t
caseCxt'' _ f (Sum wit v :&: a) = f (v :&: a) \\ dictFor @cxt (annWit wit)
  where
    -- 'DistAnn' maps each @f@ to @f :&: a@ without reordering the list,
    -- so @f :&: a@ sits in @DistAnn fs a@ at the same place @f@ sits in
    -- @fs@. GHC cannot prove this, hence the 'unsafeElem' coercion.
    -- NOTE(review): assumes 'Elem' witnesses are purely positional —
    -- confirm in Data.Comp.Elem.
    annWit :: Elem f fs -> Elem (f :&: a) (DistAnn fs a)
    annWit = unsafeElem


-- | Distribute the annotation type @a@ over a list of signatures: each
-- element @f@ becomes @f :&: a@, and the order of the list is preserved.
type family DistAnn (fs :: [(* -> *) -> * -> *]) (a :: *) :: [(* -> *) -> * -> *] where
  DistAnn (f ': fs) a = f :&: a ': DistAnn fs a
  DistAnn '[]       _ = '[]

-- | Push the annotation on a sum into the selected summand. This turns
-- @(Sum fs :&: a)@ into a sum over the annotated signatures
-- @DistAnn fs a@. The membership witness @wit@ is handed to
-- 'unsafeMapSum' together with the function @(:&: a)@. This presumably
-- relies on 'DistAnn' preserving list positions (see its equations).
distAnn :: (Sum fs :&: a) e :-> Sum (DistAnn fs a) e
distAnn (Sum wit v :&: a) = unsafeMapSum wit v (:&: a)