module Data.Comp.Trans.DeriveMulti (
    deriveMulti
  ) where

import Control.Lens ( _1, _2, _3, (&), (%~), (%%~), (^.), view )
import Control.Monad ( liftM )
import Control.Monad.Trans ( MonadTrans(lift) )


import Language.Haskell.TH.Syntax hiding ( lift )
import Language.Haskell.TH.ExpandSyns ( expandSyns )

import Data.Comp.Trans.Util

-- | Derive the multi-sorted compositional datatype for the datatype named @n@.
--
-- The type is reified, its substitutions and type arguments are looked up,
-- and if 'containsAll' reports that the context's 'substitutions' cover the
-- type arguments returned by 'getTypeArgs', the substitutions are applied to
-- the constructors of the @data@ / @newtype@ declaration, which are then
-- handed to 'mkGADT'.
--
-- In every other case an error mentioning @n@ is reported with 'reportError'
-- and no declarations are produced. 'reportError' records the error but lets
-- the splice computation carry on, which is why @[]@ is still returned.
deriveMulti :: Name -> CompTrans [Dec]
deriveMulti n = do
  inf      <- CompTrans $ lift $ reify n
  substs   <- view substitutions
  typeArgs <- getTypeArgs n
  if containsAll substs typeArgs
    then case inf of
      TyConI (DataD _ nm _ _ cons _)   -> mkGADT nm (applySubsts substs cons)
      TyConI (NewtypeD _ nm _ _ con _) -> mkGADT nm [applySubsts substs con]
      _ -> giveUp ", which is not a nullary datatype (and does not have concrete values supplied for type args)"
    else giveUp " but it has type arguments which are not substituted away"
  where
    -- Report a (non-fatal) error about @n@ and produce no declarations.
    -- Both failure paths share the same message prefix; only the reason differs.
    giveUp :: String -> CompTrans [Dec]
    giveUp reason = do
      CompTrans $ lift $ reportError $
        "Attempted to derive multi-sorted compositional data type for " ++ show n ++ reason
      return []

-- | Check that constructor @con@ has at most one field whose type satisfies
-- 'isVar'; if it has two or more, fail with an error naming the constructor.
--
-- Only 'RecC' and 'NormalC' constructors are supported. Any other constructor
-- form is rejected with 'fail' in 'CompTrans' (the same way 'mkCon' rejects
-- it) rather than with a pure 'error', so the problem is reported through the
-- monad instead of as an exception thrown while the field list is forced.
checkUniqueVar :: Con -> CompTrans ()
checkUniqueVar con = do
  fieldTypes <- conFieldTypes
  -- Matching on two elements stops as soon as a second 'isVar' field is seen.
  case filter isVar fieldTypes of
    (_ : _ : _) -> fail $ "comptrans: Multiple annotation fields in constructor: " ++ show con
    _           -> return ()
  where
    -- Types of the constructor's fields, in declaration order.
    conFieldTypes :: CompTrans [Type]
    conFieldTypes = case con of
      RecC _ vsts   -> return $ map (^. _3) vsts
      NormalC _ sts -> return $ map (^. _2) sts
      _             -> fail $ "Attempted to derive multi-sorted compositional datatype for something with non-normal constructors: " ++ show con

-- | Build the declarations for the multi-sorted datatype derived from the type
-- named @n@, whose (already substituted) constructors are @cons@.
--
-- Two declarations are returned:
--
--   * a datatype named @'transName' n@ with a parameter @e :: * -> *@ and an
--     index parameter @i@ (both fresh names), whose constructors are produced
--     by 'mkCon';
--   * an empty datatype named @'nameLab' n@.
--
-- When the context has an 'annotationProp', every constructor is first
-- checked with 'checkUniqueVar'.
--
-- NOTE(review): 'mkCon' is given @transName n@ and builds its index constraint
-- from @nameLab (transName n)@, while the label type declared here is
-- @nameLab n@. This relies on the two names coinciding; confirm against the
-- definitions of 'transName' and 'nameLab'.
mkGADT :: Name -> [Con] -> CompTrans [Dec]
mkGADT n cons = do
  e <- CompTrans $ lift $ newName "e"
  i <- CompTrans $ lift $ newName "i"
  annProp <- view annotationProp
  -- The per-constructor check only runs when 'annotationProp' is set.
  maybe (return ()) (const $ mapM_ checkUniqueVar cons) annProp
  cons' <- mapM (mkCon n' e i) cons
  return
    [ DataD [] n' [KindedTV e (AppT (AppT ArrowT StarT) StarT), PlainTV i] Nothing cons' []
    , DataD [] (nameLab n) [] Nothing [] []
    ]
  where
    n' = transName n

-- | Translate one constructor of the source datatype into a constructor of the
-- multi-sorted datatype.
--
-- Arguments: @l@ is the name of the generated datatype (its label type is
-- @'nameLab' l@), @e@ is the name of its @* -> *@ parameter and @i@ the name of
-- its index parameter.
--
-- The result is the constructor renamed with 'transName', wrapped in
-- @ForallC []@ with the equality context @i ~ nameLab l@. If the context has an
-- 'annotationProp', fields whose type satisfies its 'isAnnotation' predicate
-- are dropped. Every remaining field type is rewritten with @'unfixType' e@
-- (left to right), and record field names are renamed with 'transName'.
--
-- Only 'NormalC' and 'RecC' constructors are supported; anything else fails.
mkCon :: Name -> Name -> Name -> Con -> CompTrans Con
mkCon l e i con = case con of
    NormalC n sts -> do
      isAnn <- annotationFilter
      liftM (ForallC [] indexCtx . NormalC (transName n)) $
        filter (not . isAnn . (^. _2)) sts
          & (traverse . _2) %%~ unfixType e
    RecC n vsts -> do
      isAnn <- annotationFilter
      liftM (ForallC [] indexCtx . RecC (transName n)) $
        filter (not . isAnn . (^. _3)) vsts
          & (traverse . _1) %~ transName
          & (traverse . _3) %%~ unfixType e
    _ -> fail $ "Attempted to derive multi-sorted compositional datatype for something with non-normal constructors: " ++ show con
  where
    -- Equality constraint @i ~ nameLab l@ attached to every generated constructor.
    indexCtx :: [Type]
    indexCtx = [AppT (AppT EqualityT (VarT i)) (ConT (nameLab l))]

    -- Predicate selecting the field types to drop; when no annotation property
    -- is configured, nothing is dropped.
    annotationFilter :: CompTrans (Type -> Bool)
    annotationFilter = maybe (const False) (^. isAnnotation) <$> view annotationProp

-- | Rewrite a constructor field type for the multi-sorted datatype whose
-- @* -> *@ parameter is named @e@.
--
-- Types listed in 'baseTypes' are returned unchanged. Any other type has its
-- synonyms expanded with 'expandSyns'. It is then passed to 'getLab' together
-- with the predicate obtained from 'getIsAnn', and the result @t'@ is returned
-- as the application @e t'@.
--
-- NOTE(review): the 'baseTypes' test is made on the type before synonym
-- expansion, so a synonym for a base type is not recognised as one; confirm
-- that this is intended.
unfixType :: Name -> Type -> CompTrans Type
unfixType e t
  | t `elem` baseTypes = return t
  | otherwise = do
      checkAnn <- getIsAnn
      expanded <- CompTrans $ lift $ expandSyns t
      labelled <- getLab checkAnn expanded
      return $ AppT (VarT e) labelled