{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE UndecidableInstances #-}

-- | Count occurrences of node pairs (parent-child type combinations) in a tree.
-- Node pairs are tracked at the constructor level, not the fragment level.
-- For example, a fragment @Stat@ with constructors @Assign@ and @Do@ will
-- generate pairs like @(Assign, Var)@ and @(Do, BlockIsBlock)@, not @(Stat, Var)@.
module Cubix.Analysis.NodePairs
  ( NodePair(..)
  , countNodePairs
  , countNodePairsWithPossible
  , countNodePairsInFolder
  , possibleNodePairs
  ) where

import Control.Exception (SomeAsyncException, SomeException, evaluate, fromException, throwIO, try)
import Control.Monad (filterM, foldM)
import Data.Map.Strict (Map)
import Data.Map.Strict qualified as Map
import Data.Set (Set)
import Data.Set qualified as Set
import System.Directory (canonicalizePath, doesDirectoryExist, doesFileExist, listDirectory)
import System.FilePath ((</>), takeExtension)

import Language.Haskell.TH

import Data.Comp.Multi (All, ConstrNameHF(..), HFoldable(..), HFunctor, Term, caseCxt, unTerm)

import Cubix.Language.Parametric.Derive (sumToNames)

-- #######################################
-- Core Types
-- #######################################

-- | A node pair represents a parent-child relationship by constructor name.
-- If a node with constructor A has a child with constructor B, that's an (A, B) pair.
data NodePair = NodePair
  { parentType :: String  -- ^ Constructor name of the parent node.
  , childType  :: String  -- ^ Constructor name of the child node.
  } deriving (Show, Eq, Ord)

-- #######################################
-- File System Utilities
-- #######################################

-- | Get all files with a given extension from a directory (recursive)
getFilesWithExtension :: String -> FilePath -> IO [FilePath]
getFilesWithExtension ext = walk Set.empty
  where
    -- Recursive walk. 'ancestors' holds the canonical paths of the directories
    -- on the current descent path. 'doesDirectoryExist' follows symbolic links,
    -- so without this check a link pointing back at an ancestor would make the
    -- walk recurse forever. Returned paths are still built as @dir </> entry@,
    -- exactly as the caller passed them in (not canonicalized).
    walk :: Set FilePath -> FilePath -> IO [FilePath]
    walk ancestors dir = do
      canonical <- canonicalizePath dir
      if canonical `Set.member` ancestors
        then return []  -- cycle back to an ancestor: stop descending
        else do
          let ancestors' = Set.insert canonical ancestors
          fullPaths <- map (dir </>) <$> listDirectory dir
          files <- filterM doesFileExist fullPaths
          dirs <- filterM doesDirectoryExist fullPaths
          -- 'takeExtension' keeps the leading dot, so @ext@ must include it
          -- (e.g. ".lua").
          let matching = filter ((== ext) . takeExtension) files
          nested <- concat <$> mapM (walk ancestors') dirs
          return (matching ++ nested)

-- #######################################
-- Runtime Node Pair Analysis
-- #######################################

-- | Wrapper functor base names that should be treated as pass-through.
-- These don't create node pairs themselves; their children are treated
-- as direct children of the grandparent.
-- Matched by base name (the part after the last '.'), because constructor
-- names are fully qualified.
passThroughBaseNames :: Set String
passThroughBaseNames = Set.unions
  [ Set.fromList ["ListF", "ConsF", "NilF"]       -- list wrappers
  , Set.fromList ["MaybeF", "JustF", "NothingF"]  -- Maybe wrappers
  , Set.fromList ["PairF", "TripleF"]             -- tuple wrappers
  , Set.fromList ["EitherF", "LeftF", "RightF"]   -- Either wrappers
  ]

-- | Check if a (possibly qualified) type name is a pass-through functor.
-- Extracts the base name (after last '.') and checks against known pass-throughs.
isPassThrough :: String -> Bool
isPassThrough name = lastSegment name `Set.member` passThroughBaseNames
  where
    -- The text after the final '.'; the whole string when it contains no dot.
    lastSegment :: String -> String
    lastSegment = reverse . takeWhile (/= '.') . reverse

-- | Count all node pairs in a term.
-- Returns a map from each node pair to its count in the tree.
-- Wrapper functors (ListF, MaybeF, PairF, etc.) are treated as pass-through:
-- their children are counted as direct children of the grandparent.
-- Parent names are constructor names (e.g., "Assign", "Do"), not fragment names.
countNodePairs :: forall fs l.
  ( All HFunctor fs
  , All HFoldable fs
  , All ConstrNameHF fs
  ) => Term fs l -> Map NodePair Int
countNodePairs term =
  Map.fromListWith (+) [(pair, 1) | pair <- pairsUnder Nothing term]
  where
    -- Emit one 'NodePair' per parent-child edge, as a flat list. Counting is
    -- done once at the end with 'Map.fromListWith' (+).
    --
    -- Child results must NOT be combined as 'Map's here: 'hfoldMap' combines
    -- them with the result type's 'Monoid' instance, and for 'Map' that is the
    -- left-biased 'Map.union'. That would drop the counts of all but one
    -- sibling that shares a pair, e.g. two @Assign@ statements in one block.
    pairsUnder :: Maybe String -> forall i. Term fs i -> [NodePair]
    pairsUnder mparent t
      | isPassThrough nodeName =
          -- Wrapper functor: it gets no pair of its own, and its children are
          -- attributed to the enclosing non-wrapper ancestor.
          hfoldMap (pairsUnder mparent) node
      | otherwise =
          case mparent of
            Nothing     -> childPairs  -- the root has no parent edge
            Just parent -> NodePair parent nodeName : childPairs
      where
        node = unTerm t
        nodeName = caseCxt @ConstrNameHF constrNameHF node
        -- Pairs in the subtree, with this node as the parent of its children.
        childPairs = hfoldMap (pairsUnder (Just nodeName)) node

-- | Like 'countNodePairs', but starts with all possible pairs mapped to 0.
-- Pairs that occur in the term get their actual counts; pairs that don't occur
-- remain at 0.
countNodePairsWithPossible :: forall fs l.
  ( All HFunctor fs
  , All HFoldable fs
  , All ConstrNameHF fs
  ) => Set NodePair -> Term fs l -> Map NodePair Int
countNodePairsWithPossible possiblePairs term =
  -- 'Map.union' is left-biased: observed counts take precedence, and any
  -- possible pair that never occurred keeps its 0 entry.
  Map.union observed (Map.fromSet (const 0) possiblePairs)
  where
    observed = countNodePairs term

-- | Count node pairs across all files in a folder matching a given extension.
--
-- Takes:
--   * A set of all possible node pairs (to initialize counts to 0)
--   * A file extension to filter by (e.g., ".lua", ".move")
--   * A parse function that returns @Maybe (Term fs l)@
--   * A folder path
--
-- Returns a map with all possible pairs initialized to 0, then incremented
-- based on actual occurrences in the parsed files.
countNodePairsInFolder :: forall fs l.
  ( All HFunctor fs
  , All HFoldable fs
  , All ConstrNameHF fs
  ) => Set NodePair
    -> String
    -> (FilePath -> IO (Maybe (Term fs l)))
    -> FilePath
    -> IO (Map NodePair Int)
countNodePairsInFolder possiblePairs extension parseFile folder = do
  files <- getFilesWithExtension extension folder
  -- Seed the fold with every possible pair at 0, so that every pair in
  -- 'possiblePairs' is present in the result even when no file parses.
  foldM countFile zeroCounts files
  where
    zeroCounts :: Map NodePair Int
    zeroCounts = Map.fromSet (const 0) possiblePairs

    countFile :: Map NodePair Int -> FilePath -> IO (Map NodePair Int)
    countFile acc filePath = do
      -- Count inside 'try' and force the result with 'evaluate'. This catches
      -- exceptions hidden in lazily parsed terms too, and lets each term be
      -- garbage-collected before the next file is parsed.
      result <- try (parseFile filePath >>= traverse (evaluate . countNodePairs))
                  :: IO (Either SomeException (Maybe (Map NodePair Int)))
      case result of
        Left err
          -- Never swallow asynchronous exceptions (e.g. Ctrl-C / UserInterrupt).
          | Just (_ :: SomeAsyncException) <- fromException err -> throwIO err
          -- Best effort: skip files whose parsing or counting throws.
          | otherwise -> return acc
        Right Nothing -> return acc  -- parser rejected the file
        -- Strict accumulation avoids building a chain of unevaluated unions.
        Right (Just counts) -> return $! Map.unionWith (+) acc counts

-- #######################################
-- Template Haskell: Static Node Pair Analysis
-- #######################################

-- | Extract all possible node pairs from a signature type synonym.
-- This uses Template Haskell to statically analyze the type definitions.
--
-- Usage: @$(possibleNodePairs ''MLuaSig)@
--
-- The signature type (e.g., @MLuaSig@) is a promoted type-level list like
-- @'[Fragment1, Fragment2, ...]@, created by @makeSumType@.
-- This function reifies the type synonym to extract the fragment names.
--
-- Returns a @Set NodePair@ containing pairs where:
--
--   * 'parentType' is the constructor name (e.g., "Assign", "Do", "Block")
--   * 'childType' is resolved to actual constructor names, unwrapping List/Maybe/etc.
--
-- For example, if Block has a child of type @[StatementL]@, and Statement produces
-- @StatementL@, then this generates a pair (Block, Statement).
-- Wrapper types (List, Maybe, Either, Pair) are treated as pass-through.
-- Sort labels are fully qualified to avoid collisions between identically-named
-- sorts from different modules (e.g., Lua's @BlockL@ vs parametric @BlockL@).
possibleNodePairs :: Name -> Q Exp
possibleNodePairs sigName = do
  fragNames <- sumToNames sigName
  -- Map each fully-qualified sort label to the constructors that produce it.
  sortToConstrs <- buildSortToConstrMap fragNames
  -- Every (parent constructor, child constructor) pair, sorts already resolved.
  allPairs <- concat <$> mapM (getFragmentPairs sortToConstrs) fragNames
  -- A constructor with several children of the same sort yields the same pair
  -- several times. Deduplicate at splice time so the generated list literal
  -- (and therefore compile time) stays proportional to the distinct pairs.
  let uniquePairs = Set.toList (Set.fromList allPairs)
  [| Set.fromList $(listE (map mkPairExp uniquePairs)) |]

-- ##############################
-- Sort-to-Constructor Mapping
-- ##############################

-- | Build a map from fully-qualified sort labels (e.g., "Cubix...VarDecl.BlockL")
-- to constructor names that produce them.
-- Uses fully-qualified sort names to avoid collisions between sorts with the same
-- base name from different modules.
buildSortToConstrMap :: [Name] -> Q (Map String [String])
buildSortToConstrMap fragNames = do
  constrSorts <- concat <$> traverse getConstrSorts fragNames
  -- Key by sort label; constructors sharing a sort are appended together.
  let keyedBySort = map (\(constrName, sortLabel) -> (sortLabel, [constrName])) constrSorts
  pure (Map.fromListWith (++) keyedBySort)

-- | Get all (constructorName, qualifiedSortLabel) pairs for a fragment.
-- Each constructor of the fragment gets its own entry.
getConstrSorts :: Name -> Q [(String, String)]
getConstrSorts fragName = constrSortsOf <$> reify fragName
  where
    -- Only plain data and newtype declarations contribute; any other kind of
    -- reified entity yields no (constructor, sort) entries.
    constrSortsOf :: Info -> [(String, String)]
    constrSortsOf (TyConI (DataD _ _ _ _ cons _))   = concatMap getConstrSort cons
    constrSortsOf (TyConI (NewtypeD _ _ _ _ con _)) = getConstrSort con
    constrSortsOf _                                 = []

-- | Get the (constructorName, qualifiedSort) pair from a single constructor
getConstrSort :: Con -> [(String, String)]
getConstrSort con =
  -- Constructors whose return sort cannot be determined are left out.
  maybe [] (\sortLabel -> [(qualifiedName conName, sortLabel)]) mRetSort
  where
    (conName, _, mRetSort) = normalCon' con

-- ##############################
-- Fragment Analysis
-- ##############################

-- | Get all (parent, child) pairs for a single fragment, resolving sorts to constructors.
-- The parent name is the constructor name, not the fragment name.
getFragmentPairs :: Map String [String] -> Name -> Q [(String, String)]
getFragmentPairs sortToConstrs fragName = reify fragName >>= pairsFor
  where
    pairsFor :: Info -> Q [(String, String)]
    pairsFor (TyConI (DataD _ _ _ _ cons _)) =
      concat <$> traverse (getConstrPairs sortToConstrs) cons
    pairsFor (TyConI (NewtypeD _ _ _ _ con _)) =
      getConstrPairs sortToConstrs con
    pairsFor _ = do
      -- Not a data/newtype declaration: warn at compile time and contribute nothing.
      reportWarning ("possibleNodePairs: " ++ show fragName ++ " is not a data type")
      pure []

-- | Get pairs for a single constructor.
-- The parent name is the constructor name (e.g., "Assign"), not the fragment name (e.g., "Stat").
getConstrPairs :: Map String [String] -> Con -> Q [(String, String)]
getConstrPairs sortToConstrs con = do
  childSorts <- concat <$> traverse extractChildSorts argTypes
  -- Each child sort expands to every constructor known to produce it; sorts
  -- missing from the map contribute nothing.
  pure
    [ (parentName, childConstr)
    | childSort <- childSorts
    , childConstr <- Map.findWithDefault [] childSort sortToConstrs
    ]
  where
    (conName, argTypes, _) = normalCon' con
    parentName = qualifiedName conName

-- ##############################
-- Type Extraction
-- ##############################

-- | Extract child sort names from a type, unwrapping List/Maybe/Either/Pair.
-- Returns fully-qualified sort names to avoid collisions.
extractChildSorts :: Type -> Q [String]
extractChildSorts = pure . fieldSorts
  where
    -- A child field looks like @e label@, where @e@ is a type variable; the
    -- label may sit under further applications, which are peeled off.
    -- Fields of any other shape contribute no sorts.
    fieldSorts :: Type -> [String]
    fieldSorts (AppT (VarT _) label) = labelSorts label
    fieldSorts (AppT fun _)          = fieldSorts fun
    fieldSorts _                     = []

    -- Look through list/Maybe/Either/tuple wrappers to the inner sort names,
    -- each reported fully qualified via 'qualifiedName'.
    labelSorts :: Type -> [String]
    labelSorts ty = case ty of
      ConT n -> [qualifiedName n]
      VarT _ -> []  -- polymorphic sort: not resolved
      AppT ListT inner -> labelSorts inner
      AppT (ConT n) inner
        -- For Either, the other argument is reached by the generic AppT case.
        | nameBase n `elem` ["[]", "Maybe", "Either"] -> labelSorts inner
        | otherwise -> [qualifiedName n]  -- some other type constructor
      AppT fun arg -> labelSorts fun ++ labelSorts arg  -- tuples, Either's left side
      _ -> []

-- ##############################
-- TH Helpers
-- ##############################

-- | Get a fully-qualified name from a TH Name.
-- Falls back to 'nameBase' if the name has no module qualifier
-- (e.g., for locally-generated names).
qualifiedName :: Name -> String
qualifiedName n =
  maybe base (\modName -> modName ++ "." ++ base) (nameModule n)
  where
    base = nameBase n

-- | Helper: normalize a constructor to get name, arg types, and return sort name.
-- For GADT constructors, the return sort is extracted from the explicit return type.
-- For ForallC with equality constraints (from createSortInclusionType),
-- we extract the return sort from the equality constraint (e.g., @i ~ NameL@).
-- Sort names are fully qualified.
normalCon' :: Con -> (Name, [Type], Maybe String)
normalCon' con = case con of
  NormalC name fields -> (name, map snd fields, Nothing)
  RecC name fields -> (name, map recordFieldType fields, Nothing)
  InfixC lhs name rhs -> (name, [snd lhs, snd rhs], Nothing)
  ForallC _ ctx inner ->
    -- Prefer a return sort found on the inner constructor; otherwise look for
    -- an equality constraint such as @i ~ SortL@ in this context.
    let (name, args, mRet) = normalCon' inner
    in (name, args, maybe (extractSortFromEqConstraints ctx) Just mRet)
  -- NOTE(review): for a signature declaring several GADT constructors at once
  -- (@A, B :: ...@) only the first name is kept here.
  GadtC (name:_) fields ret -> (name, map snd fields, extractSortFromReturnType ret)
  GadtC [] _ _ -> error "Empty GADT constructor list"
  RecGadtC (name:_) fields ret -> (name, map recordFieldType fields, extractSortFromReturnType ret)
  RecGadtC [] _ _ -> error "Empty RecGADT constructor list"
  where
    recordFieldType (_, _, t) = t

-- | Extract sort label from a GADT return type like @Foo e SortL@
-- Returns a fully-qualified sort name.
extractSortFromReturnType :: Type -> Maybe String
extractSortFromReturnType ty = case ty of
  -- Outermost application whose argument is a type constructor: that is the sort.
  AppT _ (ConT sortName) -> Just (qualifiedName sortName)
  -- Otherwise keep searching in the function part of the application.
  AppT fun _             -> extractSortFromReturnType fun
  _                      -> Nothing

-- | Extract a return sort from equality constraints in a ForallC.
-- deriveMultiComp and createSortInclusionType generate constructors like:
--   @ForallC [] [i ~ SortL] (NormalC ConName [...])@
-- GHC represents @~@ as either 'EqualityT' or @ConT ''(~)@ depending on version.
-- We handle both forms.
-- Returns a fully-qualified sort name.
extractSortFromEqConstraints :: [Type] -> Maybe String
extractSortFromEqConstraints constraints =
  -- The first constraint that pins the index to a sort wins.
  case [s | Just s <- map getEqSort constraints] of
    (s : _) -> Just s
    []      -> Nothing
  where
    -- Recognize @i ~ SortL@ and, equally, @SortL ~ i@. Equality is symmetric,
    -- and a hand-written fragment may use either orientation.
    getEqSort :: Type -> Maybe String
    getEqSort (AppT (AppT eq lhs) rhs)
      | isEqualityOp eq = sortOf lhs rhs
    getEqSort _ = Nothing

    -- GHC reifies '~' either as 'EqualityT' or as a constructor named "~".
    isEqualityOp :: Type -> Bool
    isEqualityOp EqualityT = True
    isEqualityOp (ConT eq) = nameBase eq == "~"
    isEqualityOp _         = False

    sortOf :: Type -> Type -> Maybe String
    sortOf (VarT _) (ConT n) = Just (qualifiedName n)
    sortOf (ConT n) (VarT _) = Just (qualifiedName n)
    sortOf _ _               = Nothing

-- | Make a NodePair expression
mkPairExp :: (String, String) -> Q Exp
mkPairExp (parent, child) =
  -- Builds the expression @NodePair "parent" "child"@.
  appsE [conE 'NodePair, litE (stringL parent), litE (stringL child)]