{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TemplateHaskellQuotes #-}
module Clash.Normalize.Util
( ConstantSpecInfo(..)
, isConstantArg
, shouldReduce
, alreadyInlined
, addNewInline
, isRecursiveBndr
, callGraph
, collectCallGraphUniques
, classifyFunction
, isCheapFunction
, isNonRecursiveGlobalVar
, constantSpecInfo
, normalizeTopLvlBndr
, rewriteExpr
, mkInlineTick
, substWithTyEq
, tvSubstWithTyEq
)
where
import Control.Lens ((&),(+~),(%=),(.=))
import qualified Control.Lens as Lens
import Data.Bifunctor (bimap)
import Data.Either (lefts,rights)
import qualified Data.List as List
import qualified Data.List.Extra as List
import qualified Data.Map as Map
import qualified Data.HashMap.Strict as HashMapS
import qualified Data.HashSet as HashSet
import Data.Text (Text)
import qualified Data.Text as Text
import qualified Data.Text.Extra as Text
#if MIN_VERSION_ghc(9,0,0)
import GHC.Builtin.Names (eqTyConKey)
import GHC.Types.Unique (getKey)
#else
import PrelNames (eqTyConKey)
import Unique (getKey)
#endif
import Clash.Annotations.Primitive (extractPrim)
import Clash.Core.FreeVars
(globalIds, globalIdOccursIn)
import Clash.Core.HasFreeVars (isClosed)
import Clash.Core.HasType
import Clash.Core.Name (Name(nameOcc,nameUniq))
import Clash.Core.Pretty (showPpr)
import Clash.Core.Subst
(deShadowTerm, extendTvSubst, mkSubst, substTm, substTy,
substId, extendIdSubst)
import Clash.Core.Term
import Clash.Core.Type
(Type(ForAllTy,LitTy, VarTy), LitTy(SymTy), TypeView (..), tyView,
splitTyConAppM, mkPolyFunTy)
import Clash.Core.Util
(isClockOrReset)
import Clash.Core.Var (Id, TyVar, Var (..), isGlobalId)
import Clash.Core.VarEnv
(VarEnv, emptyInScopeSet, emptyVarEnv, extendVarEnv, extendVarEnvWith,
lookupVarEnv, unionVarEnvWith, unitVarEnv, extendInScopeSetList, mkInScopeSet, mkVarSet)
import qualified Clash.Data.UniqMap as UniqMap
import Clash.Debug (traceIf)
import Clash.Driver.Types
(BindingMap, Binding(..), TransformationInfo(FinalTerm), hasTransformationInfo)
import Clash.Normalize.Primitives (removedArg)
import {-# SOURCE #-} Clash.Normalize.Strategy (normalization)
import Clash.Normalize.Types
import Clash.Primitives.Util (constantArgs)
import Clash.Rewrite.Types
(RewriteMonad, TransformContext(..), bindings, curFun, debugOpts, extra,
tcCache, primitives)
import Clash.Rewrite.Util
(runRewrite, mkTmBinderFor, mkDerivedName)
import Clash.Unique
import Clash.Util (SrcSpan, makeCachedU)
-- | Decide whether argument position @i@ of the primitive named @nm@ must be
-- a constant.
--
-- @Clash.Explicit.SimIO.mealyIO@ is special-cased: its arguments 2 and 3 are
-- constant. For every other name the set of constant argument positions is
-- taken from the 'primitiveArgs' cache. On a cache miss it is computed with
-- 'constantArgs' from the compiled primitive and stored in the cache. When no
-- primitive can be extracted for the name, the answer is 'False' and nothing
-- is cached.
isConstantArg
  :: Text
  -- ^ Primitive name
  -> Int
  -- ^ Argument position
  -> RewriteMonad NormalizeState Bool
isConstantArg "Clash.Explicit.SimIO.mealyIO" i = pure (i == 2 || i == 3)
isConstantArg nm i = do
  cached <- Lens.use (extra . primitiveArgs)
  maybe computeAndCache (pure . isMember) (Map.lookup nm cached)
 where
  isMember constSet = i `elem` constSet

  -- Cache miss: look the primitive up and, when it exists, remember its set
  -- of constant argument positions for later queries.
  computeAndCache = do
    prims <- Lens.view primitives
    case extractPrim =<< HashMapS.lookup nm prims of
      Nothing ->
        pure False
      Just prim -> do
        let constSet = constantArgs nm prim
        (extra . primitiveArgs) Lens.%= Map.insert nm constSet
        pure (isMember constSet)
-- | Check whether any entry of the given context marks the current term as
-- an argument position for which 'isConstantArg' answers 'True'.
--
-- Only 'AppArg' entries with a @Just@ payload are inspected. Every other
-- context entry counts as "not constant". The search stops at the first hit.
shouldReduce
  :: Context
  -> RewriteMonad NormalizeState Bool
shouldReduce = List.anyM constantPosition
 where
  constantPosition :: CoreContext -> RewriteMonad NormalizeState Bool
  constantPosition ctxElem = case ctxElem of
    AppArg (Just (primNm, _, argIx)) -> isConstantArg primNm argIx
    _ -> pure False
-- | Look up the inline counter recorded in 'inlineHistory' for inlining
-- function @f@ into function @cf@.
--
-- Yields 'Nothing' when 'inlineHistory' has no entry for @cf@, or when
-- @cf@'s entry has no counter for @f@.
alreadyInlined
  :: Id
  -- ^ Function we want to inline (@f@)
  -> Id
  -- ^ Function in which we want to perform the inlining (@cf@)
  -> NormalizeMonad (Maybe Int)
alreadyInlined f cf = do
  history <- Lens.use inlineHistory
  pure (lookupVarEnv cf history >>= lookupVarEnv f)
-- | Record one more inlining of function @f@ into function @cf@ in
-- 'inlineHistory'.
--
-- If @cf@ has no entry yet, it gets a map containing @f@ with count 1.
-- Otherwise @f@'s counter is set to 1, or incremented if it already exists.
addNewInline
  :: Id
  -- ^ Function we want to inline (@f@)
  -> Id
  -- ^ Function in which we want to perform the inlining (@cf@)
  -> NormalizeMonad ()
addNewInline f cf = inlineHistory %= recordInline
 where
  recordInline history =
    extendVarEnvWith cf (unitVarEnv f 1) bumpExisting history

  -- Combining function: ignore the first argument and bump @f@'s counter in
  -- the second. The second is presumably the map already stored for @cf@;
  -- this relies on 'extendVarEnvWith''s argument order.
  bumpExisting _ inner = extendVarEnvWith f 1 (+) inner
-- | Test whether a term is a variable, possibly applied to arguments, that
-- is a global id and is not recursive according to 'isRecursiveBndr'.
--
-- 'isRecursiveBndr' is consulted for every variable head, including
-- non-global ones, so its cache is filled either way.
isNonRecursiveGlobalVar
  :: Term
  -> NormalizeSession Bool
isNonRecursiveGlobalVar tm = case collectArgs tm of
  (Var v, _) -> do
    recursive <- isRecursiveBndr v
    pure (isGlobalId v && not recursive)
  _ ->
    pure False
-- | Determine whether binder @f@ refers to itself in its own body. The
-- answer is cached in 'recursiveComponents'.
--
-- Binders without an entry in 'bindings' are reported as non-recursive, and
-- that answer is not cached.
isRecursiveBndr
  :: Id
  -> NormalizeSession Bool
isRecursiveBndr f = do
  cache <- Lens.use (extra . recursiveComponents)
  maybe computeAndCache pure (lookupVarEnv f cache)
 where
  computeAndCache = do
    globals <- Lens.use bindings
    case lookupVarEnv f globals of
      Nothing ->
        pure False
      Just bndr -> do
        -- Only direct self-reference is checked: we look for 'f' among the
        -- global ids occurring in its own body.
        let selfRef = f `globalIdOccursIn` bindingTerm bndr
        (extra . recursiveComponents) %= extendVarEnv f selfRef
        pure selfRef
-- | Result of 'constantSpecInfo'. It holds a term whose non-constant parts
-- were replaced by fresh variables, together with the let-bindings for
-- those variables.
data ConstantSpecInfo = ConstantSpecInfo
  { csrNewBindings :: [(Id, Term)]
    -- ^ Fresh binders paired with the terms they replace
  , csrNewTerm :: !Term
    -- ^ The rewritten term, referring to the binders in 'csrNewBindings'
  , csrFoundConstant :: !Bool
    -- ^ Whether any constant part was found at all
  }
  deriving (Show)
-- | Mark a term as entirely constant: keep it unchanged and bind nothing.
constantCsr :: Term -> ConstantSpecInfo
constantCsr t = ConstantSpecInfo
  { csrNewBindings = []
  , csrNewTerm = t
  , csrFoundConstant = True
  }
-- | Mark a term as entirely non-constant.
--
-- The term is bound to a fresh identifier, and a reference to that
-- identifier becomes the new term. The identifier's name is derived (via
-- 'mkDerivedName') from the context and the string @\"bindCsr\"@; the
-- in-scope set comes from the context.
bindCsr
  :: TransformContext
  -> Term
  -> RewriteMonad NormalizeState ConstantSpecInfo
bindCsr ctx@(TransformContext inScope _) oldTerm = do
  tcm <- Lens.view tcCache
  fresh <- mkTmBinderFor inScope tcm (mkDerivedName ctx "bindCsr") oldTerm
  pure (ConstantSpecInfo
    { csrNewBindings = [(fresh, oldTerm)]
    , csrNewTerm = Var fresh
    , csrFoundConstant = False
    })
-- | Analyse the arguments of an application with 'constantSpecInfo' and
-- combine the results.
--
-- If there are no term arguments, or at least one term argument contains a
-- constant part, the term is rebuilt with @proposedTerm@ from the rewritten
-- arguments. Type arguments are passed through unchanged, the given ticks
-- are re-applied, and all arguments' bindings are collected. Otherwise the
-- original term is bound as a whole with 'bindCsr'.
mergeCsrs
  :: TransformContext
  -> [TickInfo]
  -- ^ Ticks to put back around the rebuilt term
  -> Term
  -- ^ Original term
  -> ([Either Term Type] -> Term)
  -- ^ Builds the new term from the rewritten arguments
  -> [Either Term Type]
  -- ^ Arguments
  -> RewriteMonad NormalizeState ConstantSpecInfo
mergeCsrs ctx ticks oldTerm proposedTerm subTerms = do
  argInfos <- snd <$> List.mapAccumLM analyseArg ctx subTerms
  let termInfos = lefts argInfos
      worthSpecializing =
        null termInfos || any csrFoundConstant termInfos
  if not worthSpecializing
    then bindCsr ctx oldTerm
    else
      let rebuilt = proposedTerm (map (bimap csrNewTerm id) argInfos)
      in  pure (ConstantSpecInfo
            { csrNewBindings = concatMap csrNewBindings termInfos
            , csrNewTerm = mkTicks rebuilt ticks
            , csrFoundConstant = True
            })
 where
  -- Analyse one argument. The in-scope set is threaded through and extended
  -- with the binders created so far, presumably so that binders made for
  -- later arguments stay distinct.
  analyseArg
    :: TransformContext
    -> Either Term Type
    -> RewriteMonad NormalizeState (TransformContext, Either ConstantSpecInfo Type)
  analyseArg localCtx (Right ty) =
    pure (localCtx, Right ty)
  analyseArg localCtx@(TransformContext inScope0 tfCtx) (Left arg) = do
    info <- constantSpecInfo localCtx arg
    let inScope1 =
          extendInScopeSetList inScope0 (map fst (csrNewBindings info))
    pure (TransformContext inScope1 tfCtx, Left info)
-- | Split a term into constant and non-constant parts, for constant
-- specialization. Non-constant parts are replaced by fresh variables, and
-- their bindings are listed in 'csrNewBindings'.
--
-- Each kind of term is treated as follows:
--
--  * Clock or reset typed terms ('isClockOrReset') are kept only when they
--    are an application of the 'removedArg' primitive; otherwise they are
--    bound as a whole.
--  * Literals and closed lambdas are constant. Open lambdas are bound.
--  * Data-constructor applications are rebuilt by 'mergeCsrs'.
--  * Primitive applications, and applications of non-recursive global
--    functions other than the current function ('curFun'), are kept only
--    when analysing their arguments created no new bindings. Otherwise they
--    are bound as a whole.
--  * Everything else is bound as a whole.
constantSpecInfo
  :: TransformContext
  -> Term
  -> RewriteMonad NormalizeState ConstantSpecInfo
constantSpecInfo ctx e = do
  tcm <- Lens.view tcCache
  if isClockOrReset tcm (inferCoreTypeOf tcm e)
    then clockOrReset
    else general
 where
  clockOrReset = case collectArgs e of
    (Prim p, _) | primName p == Text.showt 'removedArg -> pure (constantCsr e)
    _ -> bindCsr ctx e

  general = case collectArgsTicks e of
    (hd@(Data _), args, ticks) ->
      mergeCsrs ctx ticks e (mkApps hd) args
    (hd@(Prim _), args, ticks) ->
      allOrNothing hd args ticks
    (Lam _ _, _, _)
      | isClosed e -> pure (constantCsr e)
      | otherwise -> bindCsr ctx e
    (hd@(Var f), args, ticks) -> do
      (curF, _) <- Lens.use curFun
      nonRecGlobal <- isNonRecursiveGlobalVar e
      if nonRecGlobal && f /= curF
        then allOrNothing hd args ticks
        else bindCsr ctx e
    (Literal _, _, _) ->
      pure (constantCsr e)
    _ ->
      bindCsr ctx e

  -- Merge the arguments' results, but keep them only when no argument
  -- needed a new binding. Otherwise bind the whole term instead.
  allOrNothing hd args ticks = do
    csr <- mergeCsrs ctx ticks e (mkApps hd) args
    if null (csrNewBindings csr)
      then pure csr
      else bindCsr ctx e
-- | Call graph: for each caller, the global ids occurring in its body,
-- mapped to their number of occurrences (as built by 'callGraph').
type CallGraph = VarEnv (VarEnv Word)

-- | All uniques mentioned in a call graph, both callers and callees.
collectCallGraphUniques :: CallGraph -> HashSet.HashSet Unique
collectCallGraphUniques cg =
  HashSet.fromList (callers ++ callees)
 where
  callers = UniqMap.keys cg
  callees = concatMap UniqMap.keys (UniqMap.elems cg)
-- | Build the call graph reachable from the root binder @rt@.
--
-- Each visited binder that has an entry in the 'BindingMap' gets a node
-- mapping every global id in its body to its number of occurrences. Those
-- callees are then visited depth-first. A unique that is already in the
-- graph, or that has no binding, is not expanded.
callGraph
  :: BindingMap
  -> Id
  -> CallGraph
callGraph bndrs rt = visit emptyVarEnv (varUniq rt)
 where
  visit :: CallGraph -> Unique -> CallGraph
  visit graph u =
    case (UniqMap.lookup u graph, UniqMap.lookup u bndrs) of
      (Nothing, Just b) ->
        let callees = countGlobals (bindingTerm b)
            graph1 = UniqMap.insert u callees graph
        in  List.foldl' visit graph1 (UniqMap.keys callees)
      _ ->
        graph

  -- Occurrence count of every global id in a term.
  countGlobals :: Term -> VarEnv Word
  countGlobals =
    Lens.foldMapByOf globalIds (unionVarEnvWith (+)) emptyVarEnv
      (`UniqMap.singleton` 1)
-- | Count three kinds of construct in a term:
--
--  * applications headed by a primitive ('primitive'),
--  * applications headed by a variable ('function'),
--  * case expressions with two or more alternatives ('selection').
--
-- Lambdas, type lambdas and ticks are looked through. For a letrec only the
-- bound right-hand sides are inspected, not its body. Nothing else is
-- descended into, so application arguments, case scrutinees and case
-- alternatives are not examined.
classifyFunction
  :: Term
  -> TermClassification
classifyFunction = step (TermClassification 0 0 0)
 where
  step :: TermClassification -> Term -> TermClassification
  step !acc tm = case tm of
    Lam _ body -> step acc body
    TyLam _ body -> step acc body
    Tick _ body -> step acc body
    Letrec binds _ -> List.foldl' step acc (map snd binds)
    App {} -> classifyApp acc (fst (collectArgs tm))
    Case _ _ (_:_:_) -> acc & selection +~ 1
    _ -> acc

  -- Classify an application by its head.
  classifyApp acc hd = case hd of
    Prim {} -> acc & primitive +~ 1
    Var {} -> acc & function +~ 1
    _ -> acc
-- | Determine whether a function adds little hardware.
--
-- Using the counts produced by 'classifyFunction', a function is cheap when
-- its body contains at most one component in total: at most one function
-- call, or at most one primitive, or at most one selection, and nothing else.
isCheapFunction
  :: Term
  -> Bool
isCheapFunction tm = case classifyFunction tm of
  TermClassification {..} ->
    -- 'classifyFunction' starts every count at 0 and only ever increments it,
    -- so the counts are non-negative. "At most one of one kind and none of
    -- the others" is therefore exactly "total <= 1".
    --
    -- A guard chain testing @_function <= 1@ first must not be used here:
    -- that guard already succeeds for every body with zero or one function
    -- calls. The primitive/selection arms would then be unreachable, and a
    -- body with a single primitive or a single selection would wrongly be
    -- classified as expensive.
    _function + _primitive + _selection <= 1
-- | Normalize a top-level binding.
--
-- The work is wrapped in 'makeCachedU', keyed on @nm@, with the 'normalized'
-- map as the cache. The term is first passed through 'deShadowTerm'. When
-- @isTop@ is set, 'substWithTyEq' is also applied. The "normalization" rewrite
-- is then run via 'rewriteExpr'.
--
-- Afterwards the previously current function ('curFun') is restored. The
-- binder's type is re-inferred from the normalized term, and the last field of
-- the 'Binding' records whether @nm'@ occurs in the normalized term.
normalizeTopLvlBndr
  :: Bool
  -> Id
  -> Binding Term
  -> NormalizeSession (Binding Term)
normalizeTopLvlBndr isTop nm (Binding nm' sp inl pr tm _) =
  makeCachedU nm (extra . normalized) $ do
    tcm <- Lens.view tcCache
    savedFun <- Lens.use curFun
    normTm <- rewriteExpr ("normalization", normalization) (nmS, prepared) (nm', sp)
    -- 'rewriteExpr' overwrites 'curFun', so put back what was current before.
    curFun .= savedFun
    let normTy  = inferCoreTypeOf tcm normTm
        selfRef = nm' `globalIdOccursIn` normTm
    pure (Binding nm'{varType = normTy} sp inl pr normTm selfRef)
 where
  nmS = showPpr (varName nm)
  -- Freshen all binders (presumably so local binders cannot clash with
  -- global ones during rewriting).
  deshadowed = deShadowTerm emptyInScopeSet tm
  prepared
    | isTop     = substWithTyEq deshadowed
    | otherwise = deshadowed
-- | Eliminate type variables that are fixed by type-equality evidence.
--
-- The leading type- and term-lambdas of the term are walked. When a term
-- lambda binds evidence of equality type @tv ~ ty@ (or @ty ~ tv@), and @tv@ is
-- bound by a type-lambda walked earlier, the following happens:
--
-- * @tv@ is substituted by @ty@ in the remaining body and the evidence
--   binder's type.
-- * @tv@'s type-lambda is dropped.
-- * Uses of the evidence binder in the body are replaced by 'removedArg'
--   applied to the substituted evidence type.
--
-- The evidence lambda itself is kept. If no equality was eliminated, the
-- original term is returned unchanged.
substWithTyEq
  :: Term
  -> Term
substWithTyEq e0 = walk [] False e0
 where
  -- Binders seen so far (innermost first), whether anything was eliminated,
  -- and the remaining term.
  walk
    :: [Either Id TyVar]
    -> Bool
    -> Term
    -> Term
  walk bndrs changed (TyLam tv body) =
    walk (Right tv : bndrs) changed body
  walk bndrs changed (Lam v body)
    | Just (tv, ty) <- eqTyVar (coreTypeOf v)
    , Right tv `elem` bndrs
    = let inScope    = mkInScopeSet (mkVarSet (rights bndrs))
          tvSubst    = extendTvSubst (mkSubst inScope) tv ty
          evidenceTy = substTy tvSubst (coreTypeOf v)
          fullSubst  = extendIdSubst tvSubst v (TyApp (Prim removedArg) evidenceTy)
          bndrs'     = Left (substId tvSubst v) : (bndrs List.\\ [Right tv])
      in walk bndrs' True (substTm "substWithTyEq e" fullSubst body)
    | otherwise
    = walk (Left v : bndrs) changed body
  walk bndrs changed body
    | changed   = mkAbstraction body (reverse bndrs)
    | otherwise = e0

  -- @Just (tv, ty)@ when the type is an application of the equality type
  -- constructor whose second or third argument is a type variable.
  eqTyVar :: Type -> Maybe (TyVar, Type)
  eqTyVar t = case tyView t of
    TyConApp tc tcArgs
      | nameUniq tc == getKey eqTyConKey -> tvFirst tcArgs
    _ -> Nothing
-- | Given the three arguments of an equality type-constructor application,
-- return the type variable and the type it is equated with.
--
-- The first argument is ignored. The variable may be in the second or third
-- position; when both are variables, the second one wins. Any other shape
-- (including a different number of arguments) yields 'Nothing'.
tvFirst :: [Type] -> Maybe (TyVar, Type)
tvFirst tys = case tys of
  [_, VarTy tv, other] -> Just (tv, other)
  [_, other, VarTy tv] -> Just (tv, other)
  _                    -> Nothing
-- | Type-level analogue of 'substWithTyEq'.
--
-- Walks the leading @forall@s and function arrows of a type. When an argument
-- type is an equality @tv ~ ty@ (or @ty ~ tv@), the following happens:
--
-- * @tv@ is dropped from the collected @forall@ binders.
-- * @tv@ is substituted by @ty@ in the rest of the type.
--
-- If no equality was found, the original type is returned unchanged.
--
-- NOTE(review): unlike 'substWithTyEq', the equality argument itself is kept
-- unsubstituted, and there is no check that @tv@ is bound by a @forall@
-- walked here. Confirm this asymmetry is intended.
tvSubstWithTyEq
  :: Type
  -> Type
tvSubstWithTyEq ty0 = walk [] False ty0
 where
  walk :: [Either TyVar Type] -> Bool -> Type -> Type
  walk acc changed (ForAllTy tv inner) =
    walk (Left tv : acc) changed inner
  walk acc changed (tyView -> FunTy arg res) =
    case eqConstraint arg of
      Just (tv, ty) ->
        let acc'    = Right arg : (acc List.\\ [Left tv])
            inScope = mkInScopeSet (mkVarSet (lefts acc'))
            subst   = extendTvSubst (mkSubst inScope) tv ty
        in walk acc' True (substTy subst res)
      Nothing ->
        walk (Right arg : acc) changed res
  walk acc changed res
    | changed   = mkPolyFunTy res (reverse acc)
    | otherwise = ty0

  -- @Just (tv, ty)@ when the type is an application of the equality type
  -- constructor whose second or third argument is a type variable.
  eqConstraint :: Type -> Maybe (TyVar, Type)
  eqConstraint t = do
    (tc, tcArgs) <- splitTyConAppM t
    if nameUniq tc == getKey eqTyConKey then tvFirst tcArgs else Nothing
-- | Apply a normalization rewrite to a term.
--
-- 'curFun' is set to the given binder and source span before rewriting; it is
-- not restored here. When the 'FinalTerm' transformation-info debug option is
-- enabled, the term is traced both before and after the rewrite.
rewriteExpr :: (String,NormRewrite) -- ^ Name of the rewrite, and the rewrite itself
            -> (String,Term)        -- ^ Name of the binder, and the term to rewrite
            -> (Id, SrcSpan)        -- ^ Function and location to record in 'curFun'
            -> NormalizeSession Term
rewriteExpr (nrwS,nrw) (bndrS,expr) (nm, sp) = do
  curFun .= (nm, sp)
  opts <- Lens.view debugOpts
  let tracing = hasTransformationInfo FinalTerm opts
      -- Trace message: "<binder> <stage> <rewrite>:\n\n<pretty term>\n"
      report stage term =
        bndrS ++ stage ++ nrwS ++ ":\n\n" ++ showPpr term ++ "\n"
      traced = traceIf tracing (report " before " expr) expr
  rewritten <- runRewrite nrwS emptyInScopeSet nrw traced
  traceIf tracing (report " after " rewritten) (pure rewritten)
-- | Build a 'NameMod' 'PrefixName' tick carrying, as a type-level symbol, the
-- occurrence name of the given binder. Everything up to and including the
-- last @'.'@ (e.g. a module qualifier) is stripped from that name.
mkInlineTick :: Id -> TickInfo
mkInlineTick n = NameMod PrefixName (LitTy (SymTy baseName))
 where
  -- 'Text.breakOnEnd' splits after the last ".", so the second component is
  -- the unqualified part (or the whole name when it contains no ".").
  (_, lastPart) = Text.breakOnEnd "." (nameOcc (varName n))
  baseName = Text.unpack lastPart