module FreeC.IR.Unification
  ( 
    UnificationError(..)
  , reportUnificationError
  , unifyOrFail
  , unifyAllOrFail
    
  , unify
  , unifyAll
  ) where
import           Control.Monad.Trans.Except     ( ExceptT, runExceptT, throwE )
import           Data.Composition               ( (.:) )
import           FreeC.Environment
import           FreeC.Environment.Entry
import           FreeC.Environment.LookupOrFail
import           FreeC.IR.SrcSpan
import           FreeC.IR.Subst
import           FreeC.IR.Subterm
import qualified FreeC.IR.Syntax                as IR
import           FreeC.IR.TypeSynExpansion
import           FreeC.Monad.Converter
import           FreeC.Monad.Reporter
import           FreeC.Pretty                   ( showPretty )
-- | Errors that can be raised by 'unify' and 'unifyAll'.
data UnificationError
  = UnificationError IR.Type IR.Type
    -- ^ The two types disagree and could not be reconciled (not even after
    --   expanding type synonyms). The first type is reported as the actual
    --   type and the second one as the expected type.
  | OccursCheckFailure IR.TypeVarIdent IR.Type
    -- ^ The type variable cannot be bound to the type because it occurs
    --   within that type (the binding would describe an infinite type).
  | RigidTypeVarError SrcSpan IR.TypeVarIdent IR.Type
    -- ^ The type variable has an entry in the type scope of the environment
    --   and therefore must not be bound to the given type. The source span
    --   is the one recorded in that environment entry.
-- | Reports the given unification error as a fatal error at the given source
--   span (using 'reportFatal').
reportUnificationError :: MonadReporter m => SrcSpan -> UnificationError -> m a
reportUnificationError srcSpan = reportFatal . Message srcSpan Error . describe
 where
  -- Renders the human-readable text of the error message.
  describe :: UnificationError -> String
  describe (UnificationError actualType expectedType) = concat
    [ "Could not match expected type `"
    , showPretty expectedType
    , "` with actual type `"
    , showPretty actualType
    , "`."
    ]
  describe (OccursCheckFailure x u) = concat
    [ "Occurs check: Could not construct infinite type `"
    , showPretty x
    , "` ~ `"
    , showPretty u
    , "`."
    ]
  describe (RigidTypeVarError xSrcSpan x u) = concat
    [ "Could not match rigid type variable `"
    , x
    , "` (bound at `"
    , showPretty xSrcSpan
    , "`) with type `"
    , showPretty u
    , "`."
    ]
-- | Runs a unification computation and turns a 'UnificationError' into a
--   fatal error reported at the given source span.
runOrFail :: SrcSpan -> ExceptT UnificationError Converter a -> Converter a
runOrFail srcSpan mx = do
  result <- runExceptT mx
  case result of
    Left err -> reportUnificationError srcSpan err
    Right x  -> return x
-- | Like 'unify', but reports a fatal error at the given source span if the
--   two types cannot be unified.
unifyOrFail :: SrcSpan -> IR.Type -> IR.Type -> Converter (Subst IR.Type)
unifyOrFail srcSpan = failOnError .: unify
 where
  failOnError
    :: ExceptT UnificationError Converter (Subst IR.Type)
    -> Converter (Subst IR.Type)
  failOnError = runOrFail srcSpan
-- | Like 'unifyAll', but reports a fatal error at the given source span if
--   the types cannot be unified.
unifyAllOrFail :: SrcSpan -> [IR.Type] -> Converter (Subst IR.Type)
unifyAllOrFail srcSpan types = runOrFail srcSpan (unifyAll types)
-- | Computes the most general unifier of the two given types.
--
--   The types are compared up to their first disagreement. If one of the
--   disagreeing subterms is a type variable, that variable is bound to the
--   other subterm and unification continues on the substituted types. When
--   both subterms are variables, an internal variable is the one that gets
--   bound. Otherwise, type synonyms at the position of the disagreement are
--   expanded and unification is retried. If neither type changes, a
--   'UnificationError' is thrown with the subterm of the first type as the
--   actual type and the subterm of the second type as the expected type.
unify
  :: IR.Type -> IR.Type -> ExceptT UnificationError Converter (Subst IR.Type)
unify t s = lift (disagreementSet t s) >>= maybe (return identitySubst) resolve
 where
  -- Handles the first disagreement found between @t@ and @s@. The order of
  -- the alternatives matters: unguarded variable cases only apply when the
  -- internal-variable guards above them fail.
  resolve
    :: (Pos, IR.Type, IR.Type)
    -> ExceptT UnificationError Converter (Subst IR.Type)
  resolve (pos, u, v) = case (u, v) of
    (IR.TypeVar _ x, IR.TypeVar _ y)
      | IR.isInternalIdent x -> bindVar x v
      | IR.isInternalIdent y -> bindVar y u
    (IR.TypeVar _ x, _) -> bindVar x v
    (_, IR.TypeVar _ y) -> bindVar y u
    _ -> retryExpanded pos u v

  -- Expands type synonyms at the given position in both @t@ and @s@ and
  -- unifies the results; reports the disagreeing subterms if nothing changed.
  retryExpanded
    :: Pos
    -> IR.Type
    -> IR.Type
    -> ExceptT UnificationError Converter (Subst IR.Type)
  retryExpanded pos u v = do
    t' <- lift $ expandTypeSynonymAt pos t
    s' <- lift $ expandTypeSynonymAt pos s
    if t == t' && s == s'
      then throwE $ UnificationError u v
      else unify t' s'

  -- Binds the type variable @x@ to @u@, applies that binding to @t@ and @s@,
  -- unifies the results and composes the substitutions.
  bindVar
    :: IR.TypeVarIdent
    -> IR.Type
    -> ExceptT UnificationError Converter (Subst IR.Type)
  bindVar x u = do
    rejectRigid
    rejectCyclic
    let subst = singleSubst (IR.UnQual (IR.Ident x)) u
    mgu <- unify (applySubst subst t) (applySubst subst s)
    return (composeSubst mgu subst)
   where
    -- A type variable with an entry in the type scope of the environment is
    -- treated as rigid and must not be bound.
    rejectRigid :: ExceptT UnificationError Converter ()
    rejectRigid = do
      maybeEntry
        <- lift $ inEnv $ lookupEntry IR.TypeScope (IR.UnQual (IR.Ident x))
      maybe (return ())
            (\entry -> throwE $ RigidTypeVarError (entrySrcSpan entry) x u)
            maybeEntry

    -- Binding @x@ to a type that contains @x@ would yield an infinite type.
    rejectCyclic :: ExceptT UnificationError Converter ()
    rejectCyclic
      | x `occursIn` u = throwE $ OccursCheckFailure x u
      | otherwise      = return ()

  -- Tests whether the type variable occurs anywhere in the given type.
  occursIn :: IR.TypeVarIdent -> IR.Type -> Bool
  occursIn y (IR.TypeVar _ z)    = y == z
  occursIn _ (IR.TypeCon _ _)    = False
  occursIn y (IR.TypeApp _ l r)  = occursIn y l || occursIn y r
  occursIn y (IR.FuncType _ l r) = occursIn y l || occursIn y r
-- | Computes the most general unifier of all given types.
--
--   Returns the identity substitution if fewer than two types are given.
--   Otherwise the first two types are unified, the resulting substitution is
--   applied to all remaining types, and those are unified recursively. The
--   final substitution applies the unifier of the first pair first and the
--   unifier of the rest afterwards.
unifyAll :: [IR.Type] -> ExceptT UnificationError Converter (Subst IR.Type)
unifyAll []             = return identitySubst
unifyAll [_]            = return identitySubst
unifyAll (t0 : t1 : ts) = do
  mgu <- unify t0 t1
  -- The remaining types may mention variables bound by mgu. mgu must be
  -- applied to all of them (not just t1); otherwise mgu' can bind those
  -- variables again in a way that is inconsistent with mgu.
  mgu' <- unifyAll (map (applySubst mgu) (t1 : ts))
  -- 'composeSubst' applies its second argument first (the variable-binding
  -- case of 'unify' composes the later unifier with the earlier binding in
  -- the same way). Since mgu' was computed on types already rewritten by
  -- mgu, mgu must be applied first.
  return (composeSubst mgu' mgu)
-- | The result of 'disagreementSet': 'Nothing' if the two types agree, or
--   otherwise the position of the first disagreement (in left-to-right
--   order) together with the disagreeing subterms of the first and the
--   second type, respectively.
type DisagreementSet = Maybe (Pos, IR.Type, IR.Type)
-- | Finds the first position at which the two given types differ.
--
--   Type variables agree only if they have the same name. Type constructors
--   agree if their names are equal or if both names resolve (in the type
--   scope) to environment entries with the same 'entryName'. Type and
--   function applications are compared component-wise from left to right.
--   Any other combination disagrees at the root.
disagreementSet :: IR.Type -> IR.Type -> Converter DisagreementSet
disagreementSet t s = case (t, s) of
  (IR.TypeVar _ x, IR.TypeVar _ y) | x == y -> agree
  (IR.TypeCon _ c, IR.TypeCon _ d)
    | c == d    -> agree
    | otherwise -> do
      entryT <- lookupEntryOrFail (IR.typeSrcSpan t) IR.TypeScope c
      entryS <- lookupEntryOrFail (IR.typeSrcSpan s) IR.TypeScope d
      if entryName entryT == entryName entryS then agree else disagree
  (IR.TypeApp _ t1 t2, IR.TypeApp _ s1 s2) -> children [t1, t2] [s1, s2]
  (IR.FuncType _ t1 t2, IR.FuncType _ s1 s2) -> children [t1, t2] [s1, s2]
  _ -> disagree
 where
  agree, disagree :: Converter DisagreementSet
  agree    = return Nothing
  disagree = return (Just (rootPos, t, s))

  -- Children are numbered starting at 1.
  children :: [IR.Type] -> [IR.Type] -> Converter DisagreementSet
  children = disagreementSet' 1
-- | Compares corresponding elements of two lists of types from left to right
--   and returns the first disagreement, with its position prefixed by the
--   index of the element pair. Indices start at the given number. Comparison
--   stops at the end of the shorter list.
disagreementSet' :: Int -> [IR.Type] -> [IR.Type] -> Converter DisagreementSet
disagreementSet' i ts ss = firstOf (zip3 [i ..] ts ss)
 where
  firstOf :: [(Int, IR.Type, IR.Type)] -> Converter DisagreementSet
  firstOf []                 = return Nothing
  firstOf ((j, t, s) : rest) = disagreementSet t s >>= maybe
    (firstOf rest)
    (\(pos, t', s') -> return (Just (consPos j pos, t', s')))