{-# LANGUAGE FlexibleInstances    #-}
{-# LANGUAGE TypeSynonymInstances #-}

module Data.Comp.Trans.Collect (
    collectTypes
  ) where

import Control.Lens ( view )
import Control.Monad ( liftM, liftM2 )
import Control.Monad.Trans ( lift )

import Data.Foldable ( fold )

import Data.Set as Set ( Set, singleton, union, difference, toList, member, empty )

import Language.Haskell.TH.Syntax hiding ( lift )
import Language.Haskell.TH.ExpandSyns ( expandSyns )

import Data.Comp.Trans.Util

-- | Finds all type names transitively referred to by the given type name
-- (via 'collectTypes''), then drops every name in the context's
-- 'excludedNames' set (presumably the standard types that need no
-- translation — confirm against the callers that populate it).
--
-- The result is in ascending order, as produced by 'toList' on a 'Set'.
collectTypes :: Name -> CompTrans [Name]
collectTypes n =
  liftM2 (\reachable exclNms -> toList (reachable `difference` exclNms))
         (fixpoint collectTypes' n)
         (view excludedNames)

-- |
-- Computes the closure of a seed element under a monadic successor
-- function: the least set that contains the seed and that, for every
-- element @y@ it contains, also contains every element of @f y@.
--
-- The iteration is semi-naive: each round applies @f@ only to the elements
-- first discovered in the previous round (the frontier), so every element
-- is passed to @f@ exactly once. It therefore terminates whenever only
-- finitely many elements are reachable from the seed, even when @f@ does not
-- return its own argument. For example, @f@ may return 'empty' for some
-- inputs, or successors may form a cycle such as @1 -> 2 -> 1@. On such a
-- cycle, repeatedly replacing the whole set by @fold@ of @mapSetM f@
-- applied to it would oscillate forever.
fixpoint :: (Ord a, Monad m) => (a -> m (Set a)) -> a -> m (Set a)
fixpoint f x = go (singleton x) (singleton x)
  where
    -- 'seen' holds every element discovered so far; 'frontier' holds the
    -- subset whose successors have not been requested from @f@ yet.
    go seen frontier
      | null frontier = return seen
      | otherwise     = do
          succs <- liftM fold $ mapSetM f frontier
          -- Only genuinely new elements are expanded in the next round.
          let fresh = succs `difference` seen
          go (seen `union` fresh) fresh

-- | 'mapM' for 'Set': runs the action on every element in ascending order
-- and gathers the results into a new set, so equal results collapse into
-- one element.
mapSetM :: (Monad m, Ord b) => (a -> m b) -> Set a -> m (Set b)
mapSetM f x = do
  results <- mapM f (toList x)
  return (mconcat [singleton r | r <- results])

-- | Successor function for 'collectTypes': maps a type name @n@ to @n@
-- itself plus every name that 'extractNames' finds in the constructors of
-- its declaration.
--
-- * If @n@ is in 'excludedNames', returns 'empty' without reifying it.
-- * Only declarations reified as @TyConI (DataD ...)@ or
--   @TyConI (NewtypeD ...)@ contribute constructors; any other 'Info'
--   yields just @{n}@.
collectTypes' :: Name -> CompTrans (Set Name)
collectTypes' n = do
  exclNms <- view excludedNames
  if n `member` exclNms
    then return empty
    else do
      inf <- CompTrans $ lift $ reify n
      childNames <- liftM concat $ mapM extractNames (constructorsOf inf)
      return $ singleton n `union` mconcat (map singleton childNames)
  where
    -- The constructors of a data/newtype declaration; nothing otherwise.
    constructorsOf (TyConI (DataD _ _ _ _ cns _))    = cns
    constructorsOf (TyConI (NewtypeD _ _ _ _ con _)) = [con]
    constructorsOf _                                 = []

-- | Syntax fragments (constructors, fields, types) from which referenced
-- type names can be collected inside the 'CompTrans' monad.
class ExtractNames a where
  -- | Returns the names found in the argument. The instances in this module
  -- concatenate sub-results with '(++)', so the list may contain duplicates.
  extractNames :: a -> CompTrans [Name]

-- | Collects the names mentioned in a data constructor's field types, in
-- field order. 'ForallC' is looked through to the constructor it wraps.
--
-- The GADT forms 'GadtC' and 'RecGadtC' are handled like 'NormalC' and
-- 'RecC': only their field types are inspected. Without these two
-- equations the match is non-exhaustive, and reaching a GADT-syntax
-- declaration aborts the splice with a pattern-match failure.
-- NOTE(review): the declared GADT result type (which may mention index
-- types) is not inspected — confirm that is the desired behaviour.
instance ExtractNames Con where
  extractNames (NormalC _ xs)    = liftM concat $ mapM extractNames xs
  extractNames (RecC _ xs)       = liftM concat $ mapM extractNames xs
  extractNames (InfixC a _ b)    = liftM2 (++) (extractNames a) (extractNames b)
  extractNames (ForallC _ _ x)   = extractNames x
  extractNames (GadtC _ xs _)    = liftM concat $ mapM extractNames xs
  extractNames (RecGadtC _ xs _) = liftM concat $ mapM extractNames xs

-- | A positional field: the strictness annotation is ignored and only the
-- field's type is searched for names.
instance ExtractNames StrictType where
  extractNames = extractNames . snd

-- | A record field: the field name and strictness are ignored and only the
-- field's type is searched for names.
instance ExtractNames VarStrictType where
  extractNames (_fieldName, _strictness, fieldTy) = extractNames fieldTy

-- | Collects type-constructor names from a type after expanding its type
-- synonyms with 'expandSyns':
--
-- * 'ConT' yields its name;
-- * 'AppT' yields the names of the function part, then of the argument;
-- * every other form yields nothing.
--
-- NOTE(review): forms such as 'ForallT' and 'SigT' are not descended into,
-- so names nested under them are not collected — confirm this is intended.
instance ExtractNames Type where
  extractNames tSyn = CompTrans (lift (expandSyns tSyn)) >>= fromExpanded
    where
      fromExpanded (AppT a b) = liftM2 (++) (extractNames a) (extractNames b)
      fromExpanded (ConT n)   = return [n]
      fromExpanded _          = return []