{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TemplateHaskell   #-}
--------------------------------------------------------------------------------
-- |
-- Module      :  Data.Comp.Multi.Derive.Equality
-- Copyright   :  (c) 2011 Patrick Bahr
-- License     :  BSD3
-- Maintainer  :  Patrick Bahr <paba@diku.dk>
-- Stability   :  experimental
-- Portability :  non-portable (GHC Extensions)
--
-- Automatically derive instances of @EqHF@.
--
--------------------------------------------------------------------------------
module Data.Comp.Multi.Derive.Equality
    (
     EqHF(..),
     KEq(..),
     makeEqHF
    ) where

import Data.Comp.Derive.Utils
import Data.Comp.Multi.Equality
import Language.Haskell.TH hiding (Cxt, match)

{-| Derive an instance of 'EqHF' for a type constructor of any higher-order
  kind taking at least two arguments.

  The instance head applies the type constructor to all but its last two
  type arguments. Each of those leading arguments must be an instance of
  'Eq' (this is the generated instance context). The second-to-last argument
  is the functor argument: constructor fields whose type mentions it (as
  decided by 'containsType' on the result of 'getBinaryFArg') are compared
  with 'keq'. All other fields are compared with '=='.

  Two values are equal iff they are built from the same constructor and all
  corresponding fields are equal. If the name cannot be reified into data
  type information, or the type constructor has fewer than two arguments,
  the splice fails with a descriptive Template Haskell error instead of a
  Prelude exception. -}
makeEqHF :: Name -> Q [Dec]
makeEqHF fname = do
  mInfo <- abstractNewtypeQ $ reify fname
  DataInfo _cxt name args constrs _deriving <- case mInfo of
    Just info -> return info
    Nothing   -> fail $ "makeEqHF: could not extract data type information for "
                     ++ show fname
  -- Split the type arguments without partial 'init'/'last': the last one is
  -- left unapplied, the one before it is the functor argument, and the rest
  -- (in their original order) become the instance parameters.
  (params, functorArg) <- case reverse args of
    (_lastArg : fArg : revParams) -> return (reverse revParams, fArg)
    _ -> fail $ "makeEqHF: " ++ show name
              ++ " must take at least two type arguments"
  let argNames  = map (VarT . tyVarBndrName) params
      ftyp      = VarT (tyVarBndrName functorArg)
      complType = foldl AppT (ConT name) argNames
      preCond   = map (mkClassP ''Eq . (: [])) argNames
      classType = AppT (ConT ''EqHF) complType
  constrs' <- mapM normalConExp constrs
  eqFDecl  <- funD 'eqHF (eqFClauses ftyp constrs')
  return [mkInstanceD preCond classType [eqFDecl]]
  where
    -- All clauses of 'eqHF': one per constructor plus, where needed, a
    -- catch-all for the combinations those clauses do not cover.
    eqFClauses :: Type -> [(Name, [Type], Maybe Type)] -> [ClauseQ]
    eqFClauses ftyp cons = case cons of
      -- With no constructors there would be no clauses at all, and GHC
      -- rejects a function binding with no equations. Answer 'True', as
      -- GHC's stock-derived 'Eq' does for empty data types.
      []  -> [constClause [|True|]]
      -- A single constructor clause is already exhaustive.
      [c] -> [genEqClause ftyp c]
      -- Values built from different constructors are unequal.
      _   -> map (genEqClause ftyp) cons ++ [constClause [|False|]]

    -- The clause @eqHF _ _ = result@.
    constClause :: ExpQ -> ClauseQ
    constClause result = clause [wildP, wildP] (normalB result) []

    -- The clause comparing two values built from the same constructor,
    -- field by field.
    genEqClause :: Type -> (Name, [Type], Maybe Type) -> ClauseQ
    genEqClause ftyp (constr, argts, gadtTy) = do
      let n = length argts
      varNs  <- newNames n "x"
      varNs' <- newNames n "y"
      let pat    = ConP constr (map VarP varNs)
          pat'   = ConP constr (map VarP varNs')
          -- Invariant for the whole clause, so compute it once.
          fArgTy = getBinaryFArg ftyp gadtTy
          -- Fields mentioning the functor argument need 'keq'; the rest
          -- use plain '=='.
          mkEq ty x y =
            let (x', y') = (return x, return y)
            in if containsType ty fArgTy
                 then [| $x' `keq` $y' |]
                 else [| $x' == $y' |]
          eqs = listE $ zipWith3 mkEq argts (map VarE varNs) (map VarE varNs')
      body <- if n == 0
                then [|True|]
                else [|and $eqs|]
      return $ Clause [pat, pat'] (NormalB body) []