module FreeC.IR.Unification
(
UnificationError(..)
, reportUnificationError
, unifyOrFail
, unifyAllOrFail
, unify
, unifyAll
) where
import Control.Monad.Trans.Except ( ExceptT, runExceptT, throwE )
import Data.Composition ( (.:) )
import FreeC.Environment
import FreeC.Environment.Entry
import FreeC.Environment.LookupOrFail
import FreeC.IR.SrcSpan
import FreeC.IR.Subst
import FreeC.IR.Subterm
import qualified FreeC.IR.Syntax as IR
import FreeC.IR.TypeSynExpansion
import FreeC.Monad.Converter
import FreeC.Monad.Reporter
import FreeC.Pretty ( showPretty )
-- | The reasons why the unification of two types can fail.
data UnificationError
  = -- | Two types (or corresponding subterms of them) are incompatible.
    --   When reported, the first type is shown as the actual type and the
    --   second as the expected type.
    UnificationError IR.Type IR.Type
  | -- | The type variable would have to be bound to the given type, but that
    --   type contains the variable itself (an infinite type).
    OccursCheckFailure IR.TypeVarIdent IR.Type
  | -- | The type variable has an entry in the type scope of the environment
    --   (i.e. it is rigid) and therefore cannot be bound to the given type.
    --   The source span is the one stored in the variable's environment entry.
    RigidTypeVarError SrcSpan IR.TypeVarIdent IR.Type
-- | Reports the given unification error as a fatal error message at the
--   given source span.
--
--   The message is reported via 'reportFatal'.
reportUnificationError :: MonadReporter m => SrcSpan -> UnificationError -> m a
reportUnificationError srcSpan = reportFatal . Message srcSpan Error . explain
  where
    -- Renders the human readable error message for a unification error.
    explain :: UnificationError -> String
    explain (UnificationError actualType expectedType) = concat
      [ "Could not match expected type `"
      , showPretty expectedType
      , "` with actual type `"
      , showPretty actualType
      , "`."
      ]
    explain (OccursCheckFailure x u) = concat
      [ "Occurs check: Could not construct infinite type `"
      , showPretty x
      , "` ~ `"
      , showPretty u
      , "`."
      ]
    explain (RigidTypeVarError xSrcSpan x u) = concat
      [ "Could not match rigid type variable `"
      , x
      , "` (bound at `"
      , showPretty xSrcSpan
      , "`) with type `"
      , showPretty u
      , "`."
      ]
-- | Runs a unification computation. If it fails, the error is reported as a
--   fatal error at the given source span using 'reportUnificationError';
--   otherwise its result is returned.
runOrFail :: SrcSpan -> ExceptT UnificationError Converter a -> Converter a
runOrFail srcSpan mx = do
  result <- runExceptT mx
  case result of
    Left err -> reportUnificationError srcSpan err
    Right x  -> return x
-- | Computes the most general unifier of the two given types using 'unify'.
--   If the types cannot be unified, a fatal error is reported at the given
--   source span.
unifyOrFail :: SrcSpan -> IR.Type -> IR.Type -> Converter (Subst IR.Type)
unifyOrFail srcSpan = reportOnFailure .: unify
  where
    -- Turns a failed unification into a reported fatal error.
    reportOnFailure
      :: ExceptT UnificationError Converter (Subst IR.Type)
      -> Converter (Subst IR.Type)
    reportOnFailure = runOrFail srcSpan
-- | Computes the most general unifier of all given types using 'unifyAll'.
--   If the types cannot be unified, a fatal error is reported at the given
--   source span.
unifyAllOrFail :: SrcSpan -> [IR.Type] -> Converter (Subst IR.Type)
unifyAllOrFail srcSpan types = runOrFail srcSpan (unifyAll types)
-- | Computes the most general unifier of the two given types.
--
--   The first disagreement between the two types is located. If one side of
--   the disagreement is a type variable, that variable is bound to the other
--   side, the binding is applied to both types and unification continues on
--   the results. When both sides are type variables, a variable for which
--   'IR.isInternalIdent' holds is preferred. When neither side is a type
--   variable, type synonyms at the disagreement position are expanded and
--   unification is retried; if nothing could be expanded, a
--   'UnificationError' is thrown.
--
--   Binding a variable fails with 'RigidTypeVarError' if the variable has an
--   entry in the type scope of the environment, and with
--   'OccursCheckFailure' if the other type contains the variable.
unify
  :: IR.Type -> IR.Type -> ExceptT UnificationError Converter (Subst IR.Type)
unify t s = lift (disagreementSet t s) >>= maybe (return identitySubst) resolve
  where
    -- Resolves the first disagreement @(pos, u, v)@, where @u@ is the subterm
    -- of @t@ and @v@ the subterm of @s@ at position @pos@.
    resolve
      :: (Pos, IR.Type, IR.Type)
      -> ExceptT UnificationError Converter (Subst IR.Type)
    resolve (pos, u, v) = case (u, v) of
      -- If both sides are type variables, prefer binding an internal one.
      -- When neither is internal, fall through to the next alternative.
      (IR.TypeVar _ x, IR.TypeVar _ y)
        | IR.isInternalIdent x -> x `bindTo` v
        | IR.isInternalIdent y -> y `bindTo` u
      (IR.TypeVar _ x, _) -> x `bindTo` v
      (_, IR.TypeVar _ y) -> y `bindTo` u
      _ -> expandAndRetry pos u v

    -- Neither side is a type variable. Expanding a type synonym at the
    -- disagreement position may make the types unifiable; if neither type
    -- changed, the disagreeing subterms are incompatible.
    expandAndRetry
      :: Pos
      -> IR.Type
      -> IR.Type
      -> ExceptT UnificationError Converter (Subst IR.Type)
    expandAndRetry pos u v = do
      t' <- lift $ expandTypeSynonymAt pos t
      s' <- lift $ expandTypeSynonymAt pos s
      let progress = t /= t' || s /= s'
      if progress then unify t' s' else throwE $ UnificationError u v

    -- Binds the type variable @var@ to @ty@, applies that binding to both
    -- @t@ and @s@ and unifies the results. The returned unifier applies the
    -- binding first and then the unifier of the substituted types.
    bindTo
      :: IR.TypeVarIdent
      -> IR.Type
      -> ExceptT UnificationError Converter (Subst IR.Type)
    bindTo var ty = do
      rigidCheck
      occursCheck ty
      let binding = singleSubst varName ty
      mgu <- unify (applySubst binding t) (applySubst binding s)
      return (composeSubst mgu binding)
      where
        varName = IR.UnQual (IR.Ident var)

        -- A variable with an entry in the type scope must not be bound.
        rigidCheck :: ExceptT UnificationError Converter ()
        rigidCheck = do
          maybeEntry <- lift $ inEnv $ lookupEntry IR.TypeScope varName
          maybe (return ())
            (\entry -> throwE $ RigidTypeVarError (entrySrcSpan entry) var ty)
            maybeEntry

        -- Fails if @var@ occurs anywhere in the given type.
        occursCheck :: IR.Type -> ExceptT UnificationError Converter ()
        occursCheck (IR.TypeVar _ other)
          | other == var = throwE $ OccursCheckFailure var ty
          | otherwise = return ()
        occursCheck (IR.TypeCon _ _) = return ()
        occursCheck (IR.TypeApp _ l r) = occursCheck l >> occursCheck r
        occursCheck (IR.FuncType _ l r) = occursCheck l >> occursCheck r
-- | Computes the most general unifier of all types in the given list.
--
--   Adjacent types are unified from left to right. The unifier of the first
--   two types is applied to all remaining types before they are unified,
--   since those types may contain type variables bound by it. Lists with
--   fewer than two elements are unified by the identity substitution.
unifyAll :: [IR.Type] -> ExceptT UnificationError Converter (Subst IR.Type)
unifyAll [] = return identitySubst
unifyAll [_] = return identitySubst
unifyAll (t0 : t1 : ts) = do
  mgu <- unify t0 t1
  -- Apply @mgu@ to the whole tail, not only to @t1@. Later types may mention
  -- type variables bound by @mgu@. Leaving them unsubstituted can yield a
  -- substitution that does not unify all types, or can miss an occurs check
  -- failure.
  mgu' <- unifyAll (map (applySubst mgu) (t1 : ts))
  -- As in 'unify' (@composeSubst mgu subst@), the second argument of
  -- 'composeSubst' is the substitution that is applied first. @mgu'@ was
  -- computed on types that @mgu@ had already been applied to, so @mgu@ must
  -- come first.
  return (composeSubst mgu' mgu)
-- | The position of the first disagreement between two types, together with
--   the disagreeing subterms of the first and the second type, or 'Nothing'
--   if the types do not disagree.
type DisagreementSet = Maybe (Pos, IR.Type, IR.Type)

-- | Finds the first disagreement between two types.
--
--   Type variables agree if they have the same name. Type constructors agree
--   if they have the same name, or if their entries (looked up with
--   'lookupEntryOrFail') have the same 'entryName'. Type applications and
--   function types are compared component by component from left to right.
--   Any other combination disagrees at the root position.
disagreementSet :: IR.Type -> IR.Type -> Converter DisagreementSet
disagreementSet t s = case (t, s) of
  (IR.TypeVar _ x, IR.TypeVar _ y) | x == y -> noDisagreement
  (IR.TypeCon _ c, IR.TypeCon _ d)
    | c == d -> noDisagreement
    | otherwise -> do
      cEntry <- lookupEntryOrFail (IR.typeSrcSpan t) IR.TypeScope c
      dEntry <- lookupEntryOrFail (IR.typeSrcSpan s) IR.TypeScope d
      return $ if entryName cEntry == entryName dEntry
        then Nothing
        else rootDisagreement
  (IR.TypeApp _ t1 t2, IR.TypeApp _ s1 s2) ->
    disagreementSet' 1 [t1, t2] [s1, s2]
  (IR.FuncType _ t1 t2, IR.FuncType _ s1 s2) ->
    disagreementSet' 1 [t1, t2] [s1, s2]
  _ -> return rootDisagreement
  where
    -- The two types agree.
    noDisagreement :: Converter DisagreementSet
    noDisagreement = return Nothing

    -- The two types disagree at the root.
    rootDisagreement :: DisagreementSet
    rootDisagreement = Just (rootPos, t, s)
-- | Finds the first disagreement between corresponding elements of two
--   lists of types.
--
--   The first argument is the index of the first list elements. The index of
--   the element in which a disagreement is found is prepended to the
--   position of that disagreement using 'consPos'. Excess elements of the
--   longer list are ignored.
disagreementSet' :: Int -> [IR.Type] -> [IR.Type] -> Converter DisagreementSet
disagreementSet' i ts ss = firstOf (zip3 [i ..] ts ss)
  where
    -- Scans indexed pairs of types for the first disagreement.
    firstOf :: [(Int, IR.Type, IR.Type)] -> Converter DisagreementSet
    firstOf [] = return Nothing
    firstOf ((j, t, s) : rest) = do
      ds <- disagreementSet t s
      case ds of
        Nothing -> firstOf rest
        Just (pos, t', s') -> return (Just (consPos j pos, t', s'))