-- | Inference context tracking for better error messages
module Elara.TypeInfer.Context where

import Elara.AST.Name (Qualified, VarName)
import Elara.AST.Region (SourceRegion)
import Elara.Data.Pretty

{- | Describes the context in which type inference is happening.
This is used to provide better error messages by explaining
WHY two types are being compared.

Note that the record selectors below are partial: several constructors
share a field name (e.g. 'callSite'), while others lack it entirely, so
only apply a selector to a value built with a constructor carrying it.
-}
data InferenceContext
    = -- | Checking an argument to a function call
      CheckingFunctionArgument
        { argPosition :: !Int
        -- ^ Position of the argument being checked
        , functionName :: !(Maybe (Qualified VarName))
        -- ^ Name of the function being called, if available
        , callSite :: !SourceRegion
        -- ^ Source region of the call
        }
    | -- | Checking the result type of a function call
      CheckingFunctionResult
        { callSite :: !SourceRegion
        -- ^ Source region of the call
        }
    | -- | Checking the condition of an if expression
      CheckingIfCondition
        { ifSite :: !SourceRegion
        -- ^ Source region of the if expression
        }
    | -- | Checking that if branches have the same type
      CheckingIfBranches
        { thenSite :: !SourceRegion
        -- ^ Source region of the @then@ branch
        , elseSite :: !SourceRegion
        -- ^ Source region of the @else@ branch
        }
    | -- | Checking a branch in a match expression
      CheckingMatchBranch
        { branchIndex :: !Int
        -- ^ Index of the branch being checked
        , branchSite :: !SourceRegion
        -- ^ Source region of the branch
        }
    | -- | Checking a pattern
      CheckingPattern
        { patternSite :: !SourceRegion
        -- ^ Source region of the pattern
        }
    | -- | Checking a let binding
      CheckingLetBinding
        { bindingName :: !(Qualified VarName)
        -- ^ Name being bound
        , bindingSite :: !SourceRegion
        -- ^ Source region of the binding
        }
    | -- | Checking against a type annotation
      CheckingAnnotation
        { annotationSite :: !SourceRegion
        -- ^ Source region of the annotation
        }
    deriving stock (Generic, Show, Eq, Ord)

-- | A stack of inference contexts, with the most recent context at the head.
--
-- The 'Semigroup' and 'Monoid' instances are borrowed from the underlying
-- list via @deriving newtype@: '<>' appends the two lists (the left
-- operand's contexts come first) and 'mempty' is the empty stack.
newtype ContextStack = ContextStack [InferenceContext]
    deriving stock (Generic, Show, Eq)
    deriving newtype (Semigroup, Monoid)

-- | A stack holding no contexts.
--
-- This is the derived 'Monoid' identity, i.e. @ContextStack []@.
emptyContextStack :: ContextStack
emptyContextStack = mempty

-- | Push a context onto the stack, making it the most recent one.
--
-- The new context becomes the head of the list; older contexts are kept
-- behind it in their existing order.
pushContext :: InferenceContext -> ContextStack -> ContextStack
pushContext newest (ContextStack older) = ContextStack (newest : older)

-- | The most recently pushed context, or 'Nothing' when the stack is empty.
currentContext :: ContextStack -> Maybe InferenceContext
currentContext (ContextStack stack) = case stack of
    top : _ -> Just top
    [] -> Nothing

-- | Every context on the stack, ordered from most to least recent.
allContexts :: ContextStack -> [InferenceContext]
allContexts wrapped = case wrapped of
    ContextStack contexts -> contexts

-- Each context renders as a short "while checking ..." phrase suitable for
-- appending to a type error. Source regions are not included in the text.
instance Pretty InferenceContext where
    pretty :: InferenceContext -> Doc AnsiStyle
    pretty = \case
        CheckingFunctionArgument pos mFn _ ->
            -- Append the function name only when it is known. Writing
            -- @doc <+> mempty@ would still insert the separating space and
            -- leave a trailing blank at the end of the message.
            let argument = "while checking argument" <+> pretty pos
             in maybe argument (\fn -> argument <+> "of" <+> squotes (pretty fn)) mFn
        CheckingFunctionResult _ ->
            "while checking function result"
        CheckingIfCondition _ ->
            "while checking if condition"
        CheckingIfBranches _ _ ->
            "while checking that if branches have the same type"
        CheckingMatchBranch idx _ ->
            "while checking match branch" <+> pretty idx
        CheckingPattern _ ->
            "while checking pattern"
        CheckingLetBinding name _ ->
            "while checking let binding" <+> squotes (pretty name)
        CheckingAnnotation _ ->
            "while checking type annotation"

-- Render the stack in list order (most recent context first), combining the
-- individual contexts with 'vsep'. An empty stack renders as 'mempty'.
instance Pretty ContextStack where
    pretty :: ContextStack -> Doc AnsiStyle
    pretty (ContextStack stack) = case stack of
        [] -> mempty
        _ -> vsep (fmap pretty stack)