{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
--
-- This module enables the creation of "sort injections," stating that
-- one sort can be considered a coercive subsort of another.
module Cubix.Language.Parametric.InjF
  (
    InjF(..)
  , injectF
  , fromProjF
  , labeledInjF
  , injFAnnDef
  , injectFAnnDef

  , InjectableSorts
  , AInjF(..)
  , promoteInjRF
  ) where

import Control.Monad ( MonadPlus(..), liftM )

import Data.Proxy ( Proxy(..) )
import Data.Type.Equality ( (:~:), gcastWith )

import Data.Comp.Multi ( Cxt(..), (:-<:),  (:&:), Cxt, inject, ann, stripA, HFunctor(..), HTraversable, AnnTerm, Sum, All, CxtS, HFoldable )
import Data.Comp.Multi.Strategic ( RewriteM, GRewriteM )
import Data.Comp.Multi.Strategy.Classification ( DynCase(..), KDynCase(..) )

import Cubix.Language.Info

import Cubix.Sin.Compdata.Annotation ( MonadAnnotater, AnnotateDefault, runAnnotateDefault )

--------------------------------------------------------------------------------

-- |
-- 'InjF' allows us to create "sort injections," stating that one sort can be
-- considered a coercive subsort of another.
--
-- For example, if we wanted to parameterize whether a given syntax
-- allows arbitrary expressions to be used as function arguments,
-- we could have the function terms have arguments of sort "FunArg"
-- and create an @ExpressionIsFunArg@ signature fragment. Defining an instance
--
-- > instance (ExpressionIsFunArg :-<: fs) => InjF fs ExpL FunArgL
--
-- would then allow us to use expressions as function arguments freely.
class (All HFunctor fs) => InjF fs l l' where
  -- | Lift a context of the subsort @l@ into the supersort @l'@.
  injF :: CxtS h fs a l -> CxtS h fs a l'

  -- |
  -- Dynamically casing on subsorts: try to view an annotated context of
  -- sort @l'@ as a context of sort @l@, yielding 'Nothing' on failure.
  -- The annotation type @p@ is left abstract, so this works for any
  -- annotation (labels, source spans, ...).
  projF' :: Cxt h (Sum fs :&: p) a l' -> Maybe (Cxt h (Sum fs :&: p) a l)

  -- | 'projF'' for unannotated contexts.
  --
  -- The default implementation annotates every node with a dummy @()@
  -- (so that 'projF'' can be reused), runs 'projF'', and strips the dummy
  -- annotations off again if the projection succeeded.
  projF :: CxtS h fs a l' -> Maybe (CxtS h fs a l)
  projF = liftM stripA . projF' . ann ()

-- | Every sort is trivially a subsort of itself: injection is the identity
-- and projection always succeeds.
instance (All HFunctor fs) => InjF fs l l where
  injF   = id
  projF' = Just

-- | 'injF' but for terms. Or 'inject', but allowing sort injections:
-- a single signature node of sort @l@ is wrapped with 'inject' and the
-- result is then lifted into sort @l'@ with 'injF'.
--
-- We would like this to replace the 'inject' function outright.
injectF :: (g :-<: fs, InjF fs l l') => g (CxtS h fs a) l -> CxtS h fs a l'
injectF = injF . inject

-- | Partial counterpart of 'projF': casts a context of sort @l'@ down to
-- sort @l@, calling 'error' if the projection fails. Only use this where the
-- caller already knows the context belongs to the subsort @l@.
fromProjF :: (InjF fs l l') => CxtS h fs a l' -> CxtS h fs a l
fromProjF x = case projF x of
  Just y  -> y
  Nothing -> error "InjF.fromProjF"

-- | 'injF' for labeled terms.
--
-- The term is placed in a 'Hole' and lifted from sort @l@ to sort @l'@ with
-- 'injF'. 'annotateLabelOuter' then turns the resulting context (whose only
-- hole is the original, already-labeled term) back into a labeled term,
-- drawing labels from the annotation monad @m@.
labeledInjF :: ( MonadAnnotater Label m
              , InjF fs l l'
              , All HTraversable fs
              , All HFoldable fs
              ) => TermLab fs l -> m (TermLab fs l')
labeledInjF t = annotateLabelOuter $ injF $ Hole t

-- | 'injF' for terms carrying an arbitrary annotation type @a@.
--
-- The term is placed in a 'Hole' and lifted from sort @l@ to sort @l'@ with
-- 'injF'. 'annotateOuter' then turns the resulting context back into an
-- annotated term, run purely in the 'AnnotateDefault' monad.
--
-- The @MonadAnnotater a (AnnotateDefault a)@ constraint leaks into this
-- signature because it's technically possible to define a MonadLabeller
-- instance for AnnotateDefault. Gah!
-- FIXME: Anything that can be done about this?
injFAnnDef :: ( InjF fs l l'
             , All HTraversable fs
             , MonadAnnotater a (AnnotateDefault a)
             , All HFoldable fs
             ) => AnnTerm a fs l -> AnnTerm a fs l'
injFAnnDef t = runAnnotateDefault $ annotateOuter $ injF $ Hole t

-- | 'injectF' for annotated terms: 'inject's a single annotated node of sort
-- @l@ and lifts the resulting term into sort @l'@ with 'injFAnnDef'.
injectFAnnDef :: ( InjF fs l l'
                , All HTraversable fs
                , MonadAnnotater a (AnnotateDefault a)
                , All HFoldable fs
                ) => (Sum fs :&: a) (AnnTerm a fs) l -> AnnTerm a fs l'
injectFAnnDef = injFAnnDef . inject

--------------------------------------------------------------------------------


-- | For a signature @sig@ and a sort @sort@, the list of sorts @i@ that the
-- default 'AInjF' instance will try when looking for a view of a term as a
-- @sort@. Each listed @i@ must have an @InjF sig sort i@ instance (required
-- by the 'AInjF'' cons instance).
--
-- This is an open family; NOTE(review): instances are presumably supplied
-- per language elsewhere, as none are visible in this module.
type family InjectableSorts (sig :: [(* -> *) -> * -> *]) (sort :: *) :: [*]

-- NOTE: There should be some way to express this in terms of the Lens library
-- | Views a labeled term of some sort @l'@ as a term of sort @l@ via a sort
-- injection.
class AInjF fs l where
  -- | On success, returns the term viewed at sort @l@, together with a
  -- function that injects a (possibly rewritten) term of sort @l@ back into
  -- the original sort @l'@ inside the annotation monad @m@. Returns
  -- 'Nothing' if the term cannot be viewed as an @l@.
  ainjF :: (MonadAnnotater Label m) => TermLab fs l' -> Maybe (TermLab fs l, TermLab fs l -> m (TermLab fs l'))

-- | Worker class for 'AInjF': same contract as 'ainjF', but only the
-- candidate supersorts in the type-level list @is@ are tried, in list order.
class AInjF' fs l (is :: [*]) where
  ainjF' :: (MonadAnnotater Label m) => Proxy is -> TermLab fs l' -> Maybe (TermLab fs l, TermLab fs l -> m (TermLab fs l'))

-- | Base case: with no candidate sorts left there is nothing to project to.
instance AInjF' fs l '[] where
  ainjF' _ _ = Nothing

-- | Inductive case. If the term's sort is @i@, try to project it down to @l@
-- with 'projF'' and, on success, pair it with 'labeledInjF' for the way back.
-- If the term's sort is not @i@, move on to the remaining candidates @is@.
--
-- When the sort does match @i@ but 'projF'' fails, the result is 'Nothing'
-- straight away; the remaining candidates are not consulted.
instance ( All HTraversable fs
         , All HFoldable fs
         , AInjF' fs l is
         , InjF fs l i
         , KDynCase (Sum fs) i
         ) => AInjF' fs l (i ': is) where
  -- NOTE: Moved application of dyncase into the where clause (dcase)
  --       because the constraint KDynCase given in context was getting
  --       ignored.
  ainjF' _ x = case dcase x of
      -- The evidence l' ~ i lets 'spec' (typed at sort i) run on x.
      Just p  -> gcastWith p spec x
      Nothing -> ainjF' (Proxy :: Proxy is) x
    where
      -- Runtime sort test: is this term of sort i?
      dcase :: forall li. (KDynCase (Sum fs) i) => TermLab fs li -> Maybe (li :~: i)
      dcase a = dyncase a :: Maybe (li :~: i)

      -- Project a term of sort i down to l, returning the injection back up.
      spec :: (MonadAnnotater Label m) => TermLab fs i -> Maybe (TermLab fs l, TermLab fs l -> m (TermLab fs i))
      spec t = fmap (\t' -> (t', labeledInjF)) (projF' t)

-- | Default 'AInjF' instance: try every candidate sort listed in
-- @'InjectableSorts' fs l@, in order. Marked OVERLAPPABLE so that a more
-- specific @AInjF@ instance takes precedence when one is defined.
instance {-# OVERLAPPABLE #-} (AInjF' fs l (InjectableSorts fs l)) => AInjF fs l where
  ainjF = ainjF' (Proxy :: Proxy (InjectableSorts fs l))


-- | Promote a rewrite on sort @l@ to a rewrite on labeled terms of any sort.
--
-- If the term can be viewed as an @l@ (via 'ainjF'), that view is rewritten
-- with @f@ and the result is injected back into the term's original sort.
-- Otherwise the rewrite fails with 'mzero'.
promoteInjRF :: (AInjF fs l, MonadPlus m, MonadAnnotater Label m) => RewriteM m (TermLab fs) l -> GRewriteM m (TermLab fs)
promoteInjRF f t = case ainjF t of
  Nothing        -> mzero
  Just (t', ins) -> ins =<< f t'