module Elara.TypeInfer.Generalise where

import Data.Generics.Sum (AsAny (_As))
import Data.Set (difference, member)
import Elara.AST.Region
import Elara.Logging
import Elara.TypeInfer.Environment
import Elara.TypeInfer.Ftv
import Elara.TypeInfer.Type

import Effectful
import Effectful.State.Static.Local (State, get)

-- | Generalise a monotype into a polytype.
--
-- The quantified variables are the unification variables that occur free in
-- @ty@ but not free in the current 'TypeEnvironment'. Variables that the
-- environment also mentions are left unquantified. Skolem variables are never
-- quantified here, because only values matching the @UnificationVar@
-- constructor are collected.
--
-- The resulting 'Forall' and its 'EmptyConstraint' are both located at
-- @monotypeLoc ty@.
--
-- NOTE(review): the 'StructuredDebug' effect is required by the signature but
-- is not used by this body.
generalise ::
    forall r.
    (StructuredDebug :> r, State (TypeEnvironment SourceRegion) :> r) =>
    Monotype SourceRegion ->
    Eff r (Polytype SourceRegion)
generalise ty = do
    env <- get @(TypeEnvironment SourceRegion)
    let loc = monotypeLoc ty
        -- Free in the type but not free in the environment.
        notInEnv = ftv ty `difference` ftv env
        -- Keep only the unification variables. Because they are drawn from a
        -- 'Set', the resulting list has no duplicates, so no further
        -- conversion is needed.
        quantified = notInEnv ^.. folded % _As @"UnificationVar"
    pure (Forall loc quantified (EmptyConstraint loc) ty)

-- | Rewrite skolem type variables in a monotype back into unification
-- variables that carry the same unique.
--
-- The rewrite is applied bottom-up to every node reachable through 'plate'
-- (via 'transformOf'). A @TypeVar@ holding a @SkolemVar@ is replaced only when
-- that variable is a member of @ftv ty@. Every other node is returned
-- unchanged.
removeSkolems :: Generic loc => Monotype loc -> Monotype loc
removeSkolems ty = transformOf plate unskolemise ty
  where
    -- Free type variables of the original input, computed once.
    freeInTy = ftv ty

    unskolemise (TypeVar loc tv@(SkolemVar u))
        | tv `member` freeInTy = TypeVar loc (UnificationVar u)
    unskolemise other = other