{-# LANGUAGE DataKinds #-}
{-# LANGUAGE EmptyDataDecls #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module Cubix.Sin.Compdata.Annotation (
    Annotated(..)
  , getAnn
  , MonadAnnotater(..)
  , AnnotateDefault
  , pattern AnnotateDefault
  , runAnnotateDefault
  , annotateM
  , propAnnSigFun
  ) where

import Control.Monad.Identity ( Identity(..) )
import Data.Default ( Default(..) )
import Data.Comp.Multi ( Cxt(..), (:=>), CxtFunM, SigFun, appSigFunM, HFix, AnnHFix )
import Data.Comp.Multi.HTraversable ( HTraversable )
import Data.Comp.Multi.Ops ((:&:)(..), Sum (..), contract)

import Cubix.Sin.Compdata.Instances ()
import Data.Type.Equality

---- This exists so you can constrain a functor to be annotated without also
---- naming the unannotated functor. This makes it more convenient for inclusion
---- in constraint synonyms
class Annotated (f :: (* -> *) -> * -> *) a | f -> a where
  getAnn' :: f e l -> a

instance Annotated (f :&: a) a where
  getAnn' :: (:&:) f a e l -> a
getAnn' (_ :&: x :: a
x) = a
x

instance ( Annotated f a
         , Annotated (Sum fs) a
         ) => Annotated (Sum (f ': fs)) a where
  getAnn' :: Sum (f : fs) e l -> a
getAnn' (Sum w :: Elem f (f : fs)
w a :: f e l
a) = case Elem f (f : fs) -> Either (f :~: f) (Elem f fs)
forall k (f :: k) (g :: k) (fs :: [k]).
Elem f (g : fs) -> Either (f :~: g) (Elem f fs)
contract Elem f (f : fs)
w of
    Left Refl -> f e l -> a
forall (f :: (* -> *) -> * -> *) a (e :: * -> *) l.
Annotated f a =>
f e l -> a
getAnn' f e l
a
    Right w0 :: Elem f fs
w0  -> Sum fs e l -> a
forall (fs :: [(* -> *) -> * -> *]) a (e :: * -> *) l.
Annotated (Sum fs) a =>
Sum fs e l -> a
go (Elem f fs -> f e l -> Sum fs e l
forall (f :: (* -> *) -> * -> *) (fs :: [(* -> *) -> * -> *])
       (h :: * -> *) e.
Elem f fs -> f h e -> Sum fs h e
Sum Elem f fs
w0 f e l
a)

    where go :: (Annotated (Sum fs) a) => Sum fs e l -> a
          go :: Sum fs e l -> a
go = Sum fs e l -> a
forall (f :: (* -> *) -> * -> *) a (e :: * -> *) l.
Annotated f a =>
f e l -> a
getAnn'


getAnn :: (Annotated f a) => HFix f :=> a
getAnn :: HFix f :=> a
getAnn (Term x :: f (HFix f) i
x) = f (HFix f) i -> a
forall (f :: (* -> *) -> * -> *) a (e :: * -> *) l.
Annotated f a =>
f e l -> a
getAnn' f (HFix f) i
x

class (Monad m) => MonadAnnotater a m where
  annM :: forall f (e :: * -> *) l. f e l -> m ((f :&: a) e l)

newtype AnnotateDefault a x = AnnotateDefault' { AnnotateDefault a x -> Identity x
runAnnotateDefault' :: Identity x}
  deriving ( a -> AnnotateDefault a b -> AnnotateDefault a a
(a -> b) -> AnnotateDefault a a -> AnnotateDefault a b
(forall a b.
 (a -> b) -> AnnotateDefault a a -> AnnotateDefault a b)
-> (forall a b. a -> AnnotateDefault a b -> AnnotateDefault a a)
-> Functor (AnnotateDefault a)
forall a b. a -> AnnotateDefault a b -> AnnotateDefault a a
forall a b. (a -> b) -> AnnotateDefault a a -> AnnotateDefault a b
forall a a b. a -> AnnotateDefault a b -> AnnotateDefault a a
forall a a b.
(a -> b) -> AnnotateDefault a a -> AnnotateDefault a b
forall (f :: * -> *).
(forall a b. (a -> b) -> f a -> f b)
-> (forall a b. a -> f b -> f a) -> Functor f
<$ :: a -> AnnotateDefault a b -> AnnotateDefault a a
$c<$ :: forall a a b. a -> AnnotateDefault a b -> AnnotateDefault a a
fmap :: (a -> b) -> AnnotateDefault a a -> AnnotateDefault a b
$cfmap :: forall a a b.
(a -> b) -> AnnotateDefault a a -> AnnotateDefault a b
Functor, Functor (AnnotateDefault a)
a -> AnnotateDefault a a
Functor (AnnotateDefault a) =>
(forall a. a -> AnnotateDefault a a)
-> (forall a b.
    AnnotateDefault a (a -> b)
    -> AnnotateDefault a a -> AnnotateDefault a b)
-> (forall a b c.
    (a -> b -> c)
    -> AnnotateDefault a a
    -> AnnotateDefault a b
    -> AnnotateDefault a c)
-> (forall a b.
    AnnotateDefault a a -> AnnotateDefault a b -> AnnotateDefault a b)
-> (forall a b.
    AnnotateDefault a a -> AnnotateDefault a b -> AnnotateDefault a a)
-> Applicative (AnnotateDefault a)
AnnotateDefault a a -> AnnotateDefault a b -> AnnotateDefault a b
AnnotateDefault a a -> AnnotateDefault a b -> AnnotateDefault a a
AnnotateDefault a (a -> b)
-> AnnotateDefault a a -> AnnotateDefault a b
(a -> b -> c)
-> AnnotateDefault a a
-> AnnotateDefault a b
-> AnnotateDefault a c
forall a. Functor (AnnotateDefault a)
forall a. a -> AnnotateDefault a a
forall a a. a -> AnnotateDefault a a
forall a b.
AnnotateDefault a a -> AnnotateDefault a b -> AnnotateDefault a a
forall a b.
AnnotateDefault a a -> AnnotateDefault a b -> AnnotateDefault a b
forall a b.
AnnotateDefault a (a -> b)
-> AnnotateDefault a a -> AnnotateDefault a b
forall a a b.
AnnotateDefault a a -> AnnotateDefault a b -> AnnotateDefault a a
forall a a b.
AnnotateDefault a a -> AnnotateDefault a b -> AnnotateDefault a b
forall a a b.
AnnotateDefault a (a -> b)
-> AnnotateDefault a a -> AnnotateDefault a b
forall a b c.
(a -> b -> c)
-> AnnotateDefault a a
-> AnnotateDefault a b
-> AnnotateDefault a c
forall a a b c.
(a -> b -> c)
-> AnnotateDefault a a
-> AnnotateDefault a b
-> AnnotateDefault a c
forall (f :: * -> *).
Functor f =>
(forall a. a -> f a)
-> (forall a b. f (a -> b) -> f a -> f b)
-> (forall a b c. (a -> b -> c) -> f a -> f b -> f c)
-> (forall a b. f a -> f b -> f b)
-> (forall a b. f a -> f b -> f a)
-> Applicative f
<* :: AnnotateDefault a a -> AnnotateDefault a b -> AnnotateDefault a a
$c<* :: forall a a b.
AnnotateDefault a a -> AnnotateDefault a b -> AnnotateDefault a a
*> :: AnnotateDefault a a -> AnnotateDefault a b -> AnnotateDefault a b
$c*> :: forall a a b.
AnnotateDefault a a -> AnnotateDefault a b -> AnnotateDefault a b
liftA2 :: (a -> b -> c)
-> AnnotateDefault a a
-> AnnotateDefault a b
-> AnnotateDefault a c
$cliftA2 :: forall a a b c.
(a -> b -> c)
-> AnnotateDefault a a
-> AnnotateDefault a b
-> AnnotateDefault a c
<*> :: AnnotateDefault a (a -> b)
-> AnnotateDefault a a -> AnnotateDefault a b
$c<*> :: forall a a b.
AnnotateDefault a (a -> b)
-> AnnotateDefault a a -> AnnotateDefault a b
pure :: a -> AnnotateDefault a a
$cpure :: forall a a. a -> AnnotateDefault a a
$cp1Applicative :: forall a. Functor (AnnotateDefault a)
Applicative, Applicative (AnnotateDefault a)
a -> AnnotateDefault a a
Applicative (AnnotateDefault a) =>
(forall a b.
 AnnotateDefault a a
 -> (a -> AnnotateDefault a b) -> AnnotateDefault a b)
-> (forall a b.
    AnnotateDefault a a -> AnnotateDefault a b -> AnnotateDefault a b)
-> (forall a. a -> AnnotateDefault a a)
-> Monad (AnnotateDefault a)
AnnotateDefault a a
-> (a -> AnnotateDefault a b) -> AnnotateDefault a b
AnnotateDefault a a -> AnnotateDefault a b -> AnnotateDefault a b
forall a. Applicative (AnnotateDefault a)
forall a. a -> AnnotateDefault a a
forall a a. a -> AnnotateDefault a a
forall a b.
AnnotateDefault a a -> AnnotateDefault a b -> AnnotateDefault a b
forall a b.
AnnotateDefault a a
-> (a -> AnnotateDefault a b) -> AnnotateDefault a b
forall a a b.
AnnotateDefault a a -> AnnotateDefault a b -> AnnotateDefault a b
forall a a b.
AnnotateDefault a a
-> (a -> AnnotateDefault a b) -> AnnotateDefault a b
forall (m :: * -> *).
Applicative m =>
(forall a b. m a -> (a -> m b) -> m b)
-> (forall a b. m a -> m b -> m b)
-> (forall a. a -> m a)
-> Monad m
return :: a -> AnnotateDefault a a
$creturn :: forall a a. a -> AnnotateDefault a a
>> :: AnnotateDefault a a -> AnnotateDefault a b -> AnnotateDefault a b
$c>> :: forall a a b.
AnnotateDefault a a -> AnnotateDefault a b -> AnnotateDefault a b
>>= :: AnnotateDefault a a
-> (a -> AnnotateDefault a b) -> AnnotateDefault a b
$c>>= :: forall a a b.
AnnotateDefault a a
-> (a -> AnnotateDefault a b) -> AnnotateDefault a b
$cp1Monad :: forall a. Applicative (AnnotateDefault a)
Monad )

pattern AnnotateDefault :: x -> AnnotateDefault a x
pattern $bAnnotateDefault :: x -> AnnotateDefault a x
$mAnnotateDefault :: forall r x a. AnnotateDefault a x -> (x -> r) -> (Void# -> r) -> r
AnnotateDefault x = AnnotateDefault' (Identity x)

runAnnotateDefault :: AnnotateDefault a (AnnHFix a f l) -> AnnHFix a f l
runAnnotateDefault :: AnnotateDefault a (AnnHFix a f l) -> AnnHFix a f l
runAnnotateDefault = Identity (AnnHFix a f l) -> AnnHFix a f l
forall a. Identity a -> a
runIdentity (Identity (AnnHFix a f l) -> AnnHFix a f l)
-> (AnnotateDefault a (AnnHFix a f l) -> Identity (AnnHFix a f l))
-> AnnotateDefault a (AnnHFix a f l)
-> AnnHFix a f l
forall b c a. (b -> c) -> (a -> b) -> a -> c
. AnnotateDefault a (AnnHFix a f l) -> Identity (AnnHFix a f l)
forall a x. AnnotateDefault a x -> Identity x
runAnnotateDefault'

-- | Specializing annotation to Maybe a to aid instance selection
instance MonadAnnotater (Maybe a) (AnnotateDefault a) where
  annM :: f e l -> AnnotateDefault a ((:&:) f (Maybe a) e l)
annM x :: f e l
x = (:&:) f (Maybe a) e l -> AnnotateDefault a ((:&:) f (Maybe a) e l)
forall (m :: * -> *) a. Monad m => a -> m a
return (f e l
x f e l -> Maybe a -> (:&:) f (Maybe a) e l
forall k (f :: (* -> *) -> k -> *) a (g :: * -> *) (e :: k).
f g e -> a -> (:&:) f a g e
:&: Maybe a
forall a. Default a => a
def)

annotateM :: (HTraversable f, MonadAnnotater a m) => CxtFunM m f (f :&: a)
annotateM :: CxtFunM m f (f :&: a)
annotateM = SigFunM m f (f :&: a) -> CxtFunM m f (f :&: a)
forall (f :: (* -> *) -> * -> *) (g :: (* -> *) -> * -> *)
       (m :: * -> *).
(HTraversable f, Monad m) =>
SigFunM m f g -> CxtFunM m f g
appSigFunM forall a (m :: * -> *) (f :: (* -> *) -> * -> *) (e :: * -> *) l.
MonadAnnotater a m =>
f e l -> m ((:&:) f a e l)
SigFunM m f (f :&: a)
annM

propAnnSigFun :: SigFun f g -> SigFun (f :&: a) (g :&: a)
propAnnSigFun :: SigFun f g -> SigFun (f :&: a) (g :&: a)
propAnnSigFun f :: SigFun f g
f (t :: f a i
t :&: a :: a
a) = (f a i -> g a i
SigFun f g
f f a i
t) g a i -> a -> (:&:) g a a i
forall k (f :: (* -> *) -> k -> *) a (g :: * -> *) (e :: k).
f g e -> a -> (:&:) f a g e
:&: a
a