{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE TypeApplications      #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE UndecidableInstances  #-}

--------------------------------------------------------------------------------
-- |
-- Module      :  Data.Comp.Multi.Equality
-- Copyright   :  (c) Patrick Bahr, 2011
-- License     :  BSD3
-- Maintainer  :  Patrick Bahr <paba@diku.dk>
-- Stability   :  experimental
-- Portability :  non-portable (GHC Extensions)
--
-- This module defines equality for (higher-order) signatures, which lifts to
-- equality for (higher-order) terms and contexts. All definitions are
-- generalised versions of those in "Data.Comp.Equality".
--
--------------------------------------------------------------------------------
module Data.Comp.Multi.Equality
    (
     EqHF(..),
     KEq(..),
     heqMod
    ) where

import Data.Comp.Multi.HFoldable
import Data.Comp.Multi.HFunctor
import Data.Comp.Multi.Ops
import Data.Comp.Multi.Term
import Data.Comp.Dict
import Data.Comp.Elem
import Data.Type.Equality

-- | Index-heterogeneous equality for indexed types: @keq x y@ compares two
-- values of the indexed type @f@ whose indices may differ.
class KEq f where
  -- | Compare two @f@-values that may sit at different indices.
  keq :: f i1 -> f i2 -> Bool

{-| Signature equality for higher-order signatures. Given a way to compare
  the arguments of a signature (@KEq g@), 'eqHF' compares two @f g@-values
  that may sit at different indices. An instance @EqHF f@, together with
  @KEq a@, gives rise to an instance @KEq (Cxt h f a)@ and hence to
  @Eq (Cxt h f a i)@. -}
class EqHF f where
    eqHF :: KEq g => f g i -> f g j -> Bool

-- | Constant functors are compared by their payload only; the index
-- carried in the type of 'K' plays no role in the comparison.
instance Eq a => KEq (K a) where
    keq (K x) (K y) = x == y

-- | Values wrapped in 'E' hide their index, so the two wrapped values may
-- live at different indices; they are therefore compared with 'keq'.
instance KEq a => Eq (E a) where
    E x == E y = x `keq` y

{-|
  'EqHF' is propagated through sums: two injections are equal only if they
  inject into the same summand (as decided by 'elemEq' on their membership
  witnesses) and the injected values are equal according to that summand's
  'EqHF' instance.
-}
instance (All EqHF fs) => EqHF (Sum fs) where
    eqHF (Sum wit1 x) (Sum wit2 y) =
      case elemEq wit1 wit2 of
        -- Same summand: 'Refl' unifies the two hidden signatures, and the
        -- 'EqHF' dictionary for that signature is recovered from the
        -- 'All EqHF fs' constraint via 'dictFor'.
        Just Refl -> eqHF x y \\ dictFor @EqHF wit1
        -- Different summands can never be equal.
        Nothing   -> False

-- | Contexts are compared structurally:
--
--   * two 'Term' nodes are equal if their signature values are equal under
--     'eqHF';
--   * two holes are equal if their contents are equal under 'keq';
--   * a 'Term' never equals a 'Hole'.
instance EqHF f => EqHF (Cxt h f) where
    eqHF (Term e1) (Term e2) = e1 `eqHF` e2
    eqHF (Hole h1) (Hole h2) = h1 `keq` h2
    eqHF _         _         = False

-- | Index-heterogeneous equality on contexts is signature equality lifted
-- through 'Cxt', using @KEq a@ to compare holes.
instance (EqHF f, KEq a) => KEq (Cxt h f a) where
    keq = eqHF

{-|
  From an 'EqHF' functor an 'Eq' instance of the corresponding
  context (and term) type can be derived. At a fixed index it is just
  'keq' with both arguments at the same index.
-}
instance (EqHF f, KEq a) => Eq (Cxt h f a i) where
    (==) = keq

{-| This function implements equality of values of type @f a i@ modulo
the equality of @a@ itself. If two functorial values are equal in this
sense, 'heqMod' returns a 'Just' value containing a list of pairs
consisting of corresponding components of the two functorial
values (paired positionally, in 'htoList' order). Otherwise it
returns 'Nothing'. -}
heqMod :: (EqHF f, HFunctor f, HFoldable f) => f a i -> f b i -> Maybe [(E a, E b)]
heqMod s t
    -- Replace every component of @s@ and @t@ by the trivially-equal
    -- placeholder @K ()@, so that 'eqHF' compares the two values while
    -- ignoring their @a@/@b@ components. The erasure is written inline at
    -- both use sites: a shared local binding would be used at two different
    -- types (@f a@ and @f b@), which MonoLocalBinds (implied by GADTs) may
    -- reject.
    | hfmap (const (K ())) s `eqHF` hfmap (const (K ())) t
    = Just (htoList s `zip` htoList t)
    | otherwise
    = Nothing