-- | Inference context tracking for better error messages
module Elara.TypeInfer.Context where

import Elara.AST.Name (Qualified, VarName)
import Elara.AST.Region (SourceRegion)
import Elara.Data.Pretty

{- | Describes the context in which type inference is happening.

This is used to provide better error messages by explaining
/why/ two types are being compared.

The field selectors are partial: each one is only defined for the
constructor(s) that declare it. 'callSite' is declared by both
'CheckingFunctionArgument' and 'CheckingFunctionResult'; every other
field belongs to a single constructor. Prefer pattern matching over
applying a selector to a value whose constructor is not known.
-}
data InferenceContext typ
    = -- | Checking an argument to a function call
      CheckingFunctionArgument
        { argPosition :: !Int
        , functionName :: !(Maybe (Qualified VarName))
        , functionType :: !typ
        , actualArgumentType :: !typ
        , callSite :: !SourceRegion
        }
    | -- | Checking the result type of a function call
      CheckingFunctionResult
        { callSite :: !SourceRegion
        }
    | -- | Checking the condition of an if expression
      CheckingIfCondition
        { ifSite :: !SourceRegion
        }
    | -- | Checking that if branches have the same type
      CheckingIfBranches
        { thenSite :: !SourceRegion
        , elseSite :: !SourceRegion
        }
    | -- | Checking a branch in a match expression
      CheckingMatchBranch
        { branchIndex :: !Int
        , branchSite :: !SourceRegion
        }
    | -- | Checking a pattern
      CheckingPattern
        { patternSite :: !SourceRegion
        }
    | -- | Checking a let binding
      CheckingLetBinding
        { bindingName :: !(Qualified VarName)
        , bindingSite :: !SourceRegion
        }
    | -- | Checking against a type annotation
      CheckingAnnotation
        { annotationSite :: !SourceRegion
        }
    deriving stock (Eq, Generic, Ord, Show)

{- | A stack of inference contexts, with the most recent context at the head.

The 'Semigroup' and 'Monoid' instances are those of the underlying list
(derived via @newtype@), so @a <> b@ places the contexts of @a@ before
(i.e. as more recent than) those of @b@, and 'mempty' is the empty stack.
-}
newtype ContextStack typ = ContextStack [InferenceContext typ]
    deriving stock (Eq, Generic, Show)
    deriving newtype (Monoid, Semigroup)

-- | Create an empty context stack (a 'ContextStack' wrapping @[]@).
emptyContextStack :: ContextStack typ
emptyContextStack = ContextStack []

{- | Push a context onto the stack.

The new context becomes the head of the stack, i.e. the one returned by
'currentContext'; the existing contexts are kept unchanged beneath it.
-}
pushContext :: InferenceContext typ -> ContextStack typ -> ContextStack typ
pushContext ctx (ContextStack stack) = ContextStack (ctx : stack)

{- | Get the current (most recent) context, if any.

Returns 'Nothing' for an empty stack, otherwise 'Just' the head of the stack.
-}
currentContext :: ContextStack typ -> Maybe (InferenceContext typ)
currentContext (ContextStack []) = Nothing
currentContext (ContextStack (x : _)) = Just x

-- | Get all contexts in the stack (most recent first), by unwrapping the newtype.
allContexts :: ContextStack typ -> [InferenceContext typ]
allContexts (ContextStack stack) = stack

{- | Renders each context as a short \"while checking …\" phrase.

Only the argument position, function name, match branch index and let-binding
name are included in the output; the carried types and source regions are
ignored by this instance.
-}
instance Pretty typ => Pretty (InferenceContext typ) where
    pretty = \case
        -- Named callee: mention both the argument position and the function.
        CheckingFunctionArgument pos (Just fnName) _ _ _ ->
            "while checking argument" <+> pretty pos <+> "of" <+> squotes (pretty fnName)
        -- No callee name available: only the argument position is shown.
        CheckingFunctionArgument pos Nothing _ _ _ ->
            "while checking argument" <+> pretty pos
        CheckingFunctionResult _ ->
            "while checking function result"
        CheckingIfCondition _ ->
            "while checking if condition"
        CheckingIfBranches _ _ ->
            "while checking that if branches have the same type"
        CheckingMatchBranch idx _ ->
            "while checking match branch" <+> pretty idx
        CheckingPattern _ ->
            "while checking pattern"
        CheckingLetBinding name _ ->
            "while checking let binding" <+> squotes (pretty name)
        CheckingAnnotation _ ->
            "while checking type annotation"

{- | Renders the stack as one context per line, most recent first.

An empty stack renders as the empty document.
-}
instance Pretty typ => Pretty (ContextStack typ) where
    pretty (ContextStack []) = mempty
    pretty (ContextStack ctxs) = vsep (pretty <$> ctxs)