{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE UndecidableInstances #-}
-- |
-- This module enables the creation of "sort injections," stating that
-- one sort can be considered a coercive subsort of another.
module Cubix.Language.Parametric.InjF
  (
    InjF(..)
  , injectF
  , fromProjF
  , labeledInjF
  , injFAnnDef
  , injectFAnnDef
  , type IsSortInjection
  , type SortInjectionSource
  , type SortInjectionTarget
  , RemoveSortInjectionNode(..)

  , InjectableSorts
  , AInjF(..)
  , promoteInjRF
  ) where

import Control.Monad ( MonadPlus(..), liftM )

import Data.Default ( Default )
import Data.Proxy ( Proxy(..) )
import Data.Type.Equality ( (:~:), gcastWith )

import Data.Comp.Multi ( Signature, Sort, Cxt(..), (:-<:),  (:&:), Cxt, inject, ann, stripA, HFunctor(..), HTraversable, AnnTerm, Sum, All, CxtS, HFoldable, Fragment )
import Data.Comp.Multi.Strategic ( RewriteM, GRewriteM )
import Data.Comp.Multi.Strategy.Classification ( DynCase(..), KDynCase(..) )

import Cubix.Language.Info

import Cubix.Sin.Compdata.Annotation ( MonadAnnotater, AnnotateDefault, runAnnotateDefault, annotateOuter )
import Data.Kind (Type)

--------------------------------------------------------------------------------

-- |
-- 'InjF' allows us to create "sort injections," stating that one sort can be
-- considered a coercive subsort of another.
--
-- For example, if we wanted to parameterize whether a given syntax
-- allows arbitrary expressions to be used as function arguments,
-- we could have the function terms have arguments of sort @FunArgL@
-- and create an @ExpressionIsFunArg@ fragment (e.g. a node type that wraps
-- an expression as a function argument). Defining an instance
--
-- > instance (ExpressionIsFunArg :-<: fs) => InjF fs ExpL FunArgL
--
-- would then allow us to use expressions as function arguments freely.
class (All HFunctor fs) => InjF fs l l' where
  -- | Inject a term of sort @l@ into the sort @l'@.
  injF :: CxtS h fs a l -> CxtS h fs a l'

  -- |
  -- Dynamically casing on subsorts: try to view an annotated term of sort
  -- @l'@ as a term of sort @l@, returning 'Nothing' when that fails.
  projF' :: Cxt h (Sum fs :&: p) a l' -> Maybe (Cxt h (Sum fs :&: p) a l)

  -- | Unannotated variant of 'projF''. The default implementation attaches
  -- a unit annotation with 'ann', runs 'projF'', and then removes the
  -- annotation again with 'stripA'.
  projF :: CxtS h fs a l' -> Maybe (CxtS h fs a l)
  projF = liftM stripA . projF' . ann ()

-- | Reflexive case: every sort injects into itself. 'injF' is the identity
-- and 'projF'' always succeeds.
instance (All HFunctor fs) => InjF fs l l where
  injF = id
  projF' = Just

-- | 'injF' but for terms. Or 'inject', but allowing sort injections:
-- the node of fragment @g@ (at sort @l@) is first injected into the
-- signature with 'inject', then lifted to sort @l'@ with 'injF'.
-- We would like this to replace the 'inject' function outright.
injectF :: (g :-<: fs, InjF fs l l') => g (CxtS h fs a) l -> CxtS h fs a l'
injectF = injF . inject

-- | Partial version of 'projF': projects a term of sort @l'@ down to sort
-- @l@, calling 'error' with the message @"InjF.fromProjF"@ when the
-- projection fails. Prefer 'projF' unless failure is impossible at the
-- call site.
fromProjF :: (InjF fs l l') => CxtS h fs a l' -> CxtS h fs a l
fromProjF x = case projF x of
  Just y  -> y
  Nothing -> error "InjF.fromProjF"

-- | Apply the sort injection from @l@ to @l'@ to a labeled term.
--
-- The term is placed in a 'Hole', and 'injF' builds the injection context
-- around it. 'annotateLabelOuter' then turns that context back into a
-- labeled term. Presumably it draws fresh labels from @m@ for the nodes
-- the injection added; confirm against 'annotateLabelOuter'.
labeledInjF :: ( MonadAnnotater Label m
              , InjF fs l l'
              , All HTraversable fs
              , All HFoldable fs
              ) => TermLab fs l -> m (TermLab fs l')
labeledInjF t = annotateLabelOuter $ injF $ Hole t


-- | Like 'labeledInjF', but for terms with an arbitrary annotation type @a@,
-- and run purely.
--
-- The term is placed in a 'Hole' and injected with 'injF'. The resulting
-- context is re-annotated with 'annotateOuter' under 'runAnnotateDefault'.
-- Presumably the new nodes receive the 'Default' value of @a@; confirm.
injFAnnDef :: ( InjF fs l l'
             , All HTraversable fs
             , Default a
             , All HFoldable fs
             ) => AnnTerm a fs l -> AnnTerm a fs l'
injFAnnDef t = runAnnotateDefault $ annotateOuter $ injF $ Hole t

-- | Annotated counterpart of 'injectF': inject an annotated node of fragment
-- @f@ (at sort @l@) with 'inject', then lift it to sort @l'@ with
-- 'injFAnnDef'.
injectFAnnDef :: ( InjF fs l l'
                 , f :-<: fs
                 , All HTraversable fs
                 , All HFoldable fs
                 , Default a
                ) => (f :&: a) (AnnTerm a fs) l -> AnnTerm a fs l'
injectFAnnDef = injFAnnDef . inject

-- | Whether fragment @f@ is a sort-injection node. This is an open type
-- family; no instances are visible in this section of the module.
type family IsSortInjection (f :: Fragment) :: Bool
-- | The source sort of the sort-injection fragment @f@. This is the sort of
-- the term that 'removeSortInjectionNode' returns.
type family SortInjectionSource (f :: Fragment) :: Sort
-- | The target sort of the sort-injection fragment @f@. Presumably this is
-- the sort its nodes inhabit; confirm against the instances.
type family SortInjectionTarget (f :: Fragment) :: Sort

-- | Fragments whose nodes can be stripped away.
--
-- 'removeSortInjectionNode' takes a node of fragment @f@ (at any sort @l@)
-- and produces a context at the fragment's source sort,
-- @'SortInjectionSource' f@.
class RemoveSortInjectionNode f where
  removeSortInjectionNode
    :: f (Cxt h fs a) l
    -> Cxt h fs a (SortInjectionSource f)

--------------------------------------------------------------------------------


-- | The candidate sorts that the generic 'AInjF' instance tries, in list
-- order, when viewing a term of signature @fs@ at sort @l@. This is an open
-- type family; no instances are visible in this section of the module.
type family InjectableSorts (fs :: Signature) (l :: Sort) :: [Sort]

-- NOTE: There should be some way to express this in terms of the Lens library

-- | Labeled, monadic analogue of a 'projF' / 'injF' pair.
--
-- 'ainjF' tries to view a labeled term of some sort @l'@ as a term of sort
-- @l@. On success it also returns a function that maps an @l@-term back to
-- sort @l'@. 'promoteInjRF' uses this to run a sort-@l@ rewrite on terms of
-- other sorts.
class AInjF fs l where
  ainjF
    :: (MonadAnnotater Label m)
    => TermLab fs l'
    -> Maybe (TermLab fs l, TermLab fs l -> m (TermLab fs l'))

-- | Worker class behind the generic 'AInjF' instance. It walks the
-- type-level list @is@ of candidate sorts one at a time:
--
-- * the empty list always yields 'Nothing';
-- * a non-empty list tests the head sort and otherwise recurses on the tail.
class AInjF' fs l (is :: [Sort]) where
  ainjF'
    :: (MonadAnnotater Label m)
    => Proxy is
    -> TermLab fs l'
    -> Maybe (TermLab fs l, TermLab fs l -> m (TermLab fs l'))

-- | Base case: with no candidate sorts left, the projection fails.
instance AInjF' fs l '[] where
  ainjF' _ _ = Nothing

-- | Inductive case for the candidate list @i ': is@.
--
-- If the input term's sort is the candidate @i@ (checked dynamically with
-- 'dyncase'), the term is projected to sort @l@ with 'projF''. The result is
-- paired with 'labeledInjF' as the re-injection function. Otherwise the
-- search continues with the remaining candidates @is@.
--
-- If the sort matches but 'projF'' fails, the result is 'Nothing' and the
-- remaining candidates are not tried.
instance ( All HTraversable fs
         , All HFoldable fs
         , AInjF' fs l is
         , InjF fs l i
         , KDynCase (Sum fs) i
         ) => AInjF' fs l (i ': is) where
  -- NOTE: Moved application of dyncase into the where clause (dcase)
  --       because the constraint KDynCase given in context was getting
  --       ignored.
  ainjF' _ x = case dcase x of
      -- The sort of x is i: bring l' ~ i into scope and project.
      Just p  -> gcastWith p (spec x)
      Nothing -> ainjF' (Proxy @is) x
    where
      -- Runtime test of whether a labeled term has sort i.
      dcase :: forall li. (KDynCase (Sum fs) i) => TermLab fs li -> Maybe (li :~: i)
      dcase a = dyncase a :: Maybe (li :~: i)

      -- Project an i-term to sort l; on success, re-inject with labeledInjF.
      spec :: (MonadAnnotater Label m) => TermLab fs i -> Maybe (TermLab fs l, TermLab fs l -> m (TermLab fs i))
      spec t = fmap (\t' -> (t', labeledInjF)) (projF' t)

-- | Generic instance: try each sort in @'InjectableSorts' fs l@, in order,
-- via 'AInjF''. It is marked OVERLAPPABLE so that more specific instances
-- take precedence.
instance {-# OVERLAPPABLE #-} (AInjF' fs l (InjectableSorts fs l)) => AInjF fs l where
  ainjF = ainjF' (Proxy @(InjectableSorts fs l))


-- | Promote a rewrite at sort @l@ to a generic rewrite over all sorts.
--
-- The input term is viewed as an @l@-term with 'ainjF', rewritten with @f@,
-- and mapped back to its original sort by the function 'ainjF' returned.
-- The rewrite fails with 'mzero' when the term cannot be viewed at sort @l@.
promoteInjRF :: (AInjF fs l, MonadPlus m, MonadAnnotater Label m) => RewriteM m (TermLab fs) l -> GRewriteM m (TermLab fs)
promoteInjRF f t = case ainjF t of
  Nothing        -> mzero
  Just (t', ins) -> ins =<< f t'