{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MonoLocalBinds #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
module Language.Haskell.Tools.Refactor.Builtin.InlineBinding
(inlineBinding, tryItOut, inlineBindingRefactoring) where
import Control.Monad.State
import Control.Reference
import Data.Generics.Uniplate.Data ()
import Data.Generics.Uniplate.Operations (Uniplate(..), Biplate(..))
import Data.List (nub)
import Data.Maybe (Maybe(..), catMaybes)
import Name as GHC (NamedThing(..), Name, occNameString)
import SrcLoc as GHC (SrcSpan(..), RealSrcSpan)
import Language.Haskell.Tools.Refactor as AST
inlineBindingRefactoring :: RefactoringChoice
inlineBindingRefactoring = SelectionRefactoring "InlineBinding" inlineBinding
tryItOut :: String -> String -> IO ()
tryItOut = tryRefactor inlineBinding
inlineBinding :: RealSrcSpan -> Refactoring
inlineBinding span namedMod@(_,mod) mods
= let topLevel :: Simple Traversal Module DeclList
topLevel = nodesContaining span
local :: Simple Traversal Module LocalBindList
local = nodesContaining span
exprs :: Simple Traversal Module Expr
exprs = nodesContaining span
elemAccess :: (BindingElem d) => AnnList d -> Maybe ValueBind
elemAccess = getValBindInList span
removed = catMaybes $ map elemAccess (mod ^? topLevel) ++ map elemAccess (mod ^? local)
in case reverse removed of
[] -> refactError "No binding is selected."
removedBinding:_ ->
let [removedBindingName] = nub $ catMaybes $ map semanticsName (removedBinding ^? bindingName)
in if | any (containInlined removedBindingName) mods
-> refactError "Cannot inline the definition, it is used in other modules."
| _:_ <- mod ^? modHead & annJust & mhExports & annJust & biplateRef
& filtered (\n -> semanticsName (n :: QualifiedName) == Just removedBindingName)
-> refactError "Cannot inline the definition, it is present in the export list."
| otherwise -> localRefactoring (inlineBinding' topLevel local exprs removedBinding removedBindingName) namedMod mods
inlineBinding' :: Simple Traversal Module DeclList -> Simple Traversal Module LocalBindList
-> Simple Traversal Module Expr -> ValueBind -> GHC.Name -> LocalRefactoring
inlineBinding' topLevelRef localRef exprRef removedBinding removedBindingName mod
= do replacement <- createReplacement removedBinding
let RealSrcSpan bindingSpan = getRange removedBinding
mod' <- removeBindingAndSig topLevelRef localRef exprRef removedBindingName mod
(mod'', used) <- runStateT (descendBiM (replaceInvocations bindingSpan removedBindingName replacement) mod') False
if not used
then refactError "The selected definition is not used, it can be safely deleted."
else return mod''
containInlined :: GHC.Name -> ModuleDom -> Bool
containInlined name (_,mod)
= any (\qn -> semanticsName qn == Just name) $ (mod ^? biplateRef :: [QualifiedName])
removeBindingAndSig :: Simple Traversal Module DeclList -> Simple Traversal Module LocalBindList
-> Simple Traversal Module Expr -> GHC.Name -> LocalRefactoring
removeBindingAndSig topLevelRef localRef exprRef name
= (return . removeEmptyBnds (topLevelRef & annList & declValBind &+& localRef & annList & localVal) exprRef)
<=< (topLevelRef !~ removeBindingAndSig' name)
<=< (localRef !~ removeBindingAndSig' name)
removeBindingAndSig' :: (SourceInfoTraversal d, BindingElem d)
=> GHC.Name -> AnnList d -> LocalRefactor (AnnList d)
removeBindingAndSig' name ls = do
bnds <- mapM notThatBindOrSig (ls ^? annList)
return $ (annList .- removeNameFromSigBind) (filterListIndexed (\i _ -> bnds !! i) ls)
where notThatBindOrSig e
| Just sb <- e ^? sigBind = return $ nub (map semanticsName (sb ^? tsName & annList & simpleName)) /= [Just name]
| Just vb <- e ^? valBind = do
let isThat = nub (map semanticsName (vb ^? bindingName)) == [Just name]
when isThat (void $ accessRhs !| checkForRecursion name $ vb)
return $ not isThat
| Just fs <- e ^? fixitySig = return $ nub (map semanticsName (fs ^? fixityOperators & annList & operatorName)) /= [Just name]
| otherwise = return True
removeNameFromSigBind
= (sigBind & tsName .- filterList (\n -> semanticsName (n ^. simpleName) /= Just name))
. (fixitySig & fixityOperators .- filterList (\n -> semanticsName (n ^. operatorName) /= Just name))
accessRhs = valBindRhs
&+& valBindLocals & accessLocalRhs
&+& funBindMatches & annList & matchRhs
&+& funBindMatches & annList & (matchRhs &+& matchBinds & accessLocalRhs)
accessLocalRhs = annJust & localBinds & annList & localVal & accessRhs
checkForRecursion :: GHC.Name -> Rhs -> LocalRefactor ()
checkForRecursion n = void . (biplateRef !| checkNameForRecursion n)
checkNameForRecursion :: GHC.Name -> AST.Name -> LocalRefactor ()
checkNameForRecursion name n
| semanticsName (n ^. simpleName) == Just name
= refactError $ "Cannot inline definitions containing direct recursion. Recursive call at: "
++ shortShowSpanWithFile (getRange n)
| otherwise = return ()
replaceInvocations :: RealSrcSpan -> GHC.Name -> ([[GHC.Name]] -> [Expr] -> Expr) -> Expr
-> StateT Bool LocalRefactor Expr
replaceInvocations bindingRange name replacement expr
| (Var n, args) <- splitApps expr
, semanticsName (n ^. simpleName) == Just name
= do put True
replacement (map (map (^. _1)) $ semanticsScope expr)
<$> mapM (descendM (replaceInvocations bindingRange name replacement)) args
| otherwise
= descendM (replaceInvocations bindingRange name replacement) expr
splitApps :: Expr -> (Expr, [Expr])
splitApps (App f a) = case splitApps f of (fun, args) -> (fun, args ++ [a])
splitApps (InfixApp l (NormalOp qn) r) = (mkVar (mkParenName qn), [l,r])
splitApps (InfixApp l (BacktickOp qn) r) = (mkVar (mkNormalName qn), [l,r])
splitApps (Paren expr) = splitApps expr
splitApps expr = (expr, [])
joinApps :: Expr -> [Expr] -> Expr
joinApps f [] = f
joinApps f args = parenIfNeeded (foldl mkApp f args)
createReplacement :: ValueBind -> LocalRefactor ([[GHC.Name]] -> [Expr] -> Expr)
createReplacement (SimpleBind (VarPat _) (UnguardedRhs e) locals)
= return $ \_ args -> joinApps (parenIfNeeded $ wrapLocals locals e) args
createReplacement (SimpleBind _ _ _)
= refactError "Cannot inline, illegal simple bind. Only variable left-hand sides and unguarded right-hand sides are accepted."
createReplacement (FunctionBind (AnnList [Match lhs (UnguardedRhs expr) locals]))
= return $ \_ args -> let (argReplacement, matchedPats, appliedArgs) = matchArguments (getArgsOf lhs) args
in joinApps (parenIfNeeded (createLambda matchedPats (wrapLocals locals (replaceExprs argReplacement expr)))) appliedArgs
where getArgsOf (MatchLhs _ (AnnList args)) = args
getArgsOf (InfixLhs lhs _ rhs (AnnList more)) = lhs:rhs:more
createReplacement (FunctionBind matches)
= return $ \sc args -> let numArgs = getArgNum (head (matches ^? annList & matchLhs)) - length args
newArgs = take numArgs $ map mkName $ filter notInScope $ map (("x" ++ ) . show @Int) [1..]
notInScope str = not $ any (any ((== str) . occNameString . getOccName)) sc
in parenIfNeeded $ createLambda (map mkVarPat newArgs)
$ mkCase (mkTuple $ map mkVar newArgs ++ args)
$ map replaceMatch (matches ^? annList)
where getArgNum (MatchLhs _ (AnnList args)) = length args
getArgNum (InfixLhs _ _ _ (AnnList more)) = length more + 2
replaceExprs :: [(GHC.Name, Expr)] -> Expr -> Expr
replaceExprs [] = id
replaceExprs replaces = (uniplateRef .-) $ \case
Var n | Just name <- semanticsName (n ^. simpleName)
, Just replace <- lookup name replaces
-> replace
e -> e
matchArguments :: [Pattern] -> [Expr] -> ([(GHC.Name, Expr)], [Pattern], [Expr])
matchArguments (ParenPat p : pats) exprs = matchArguments (p:pats) exprs
matchArguments (p:pats) (e:exprs)
| Just replacement <- staticPatternMatch p e
= case matchArguments pats exprs of (replacements, patterns, expressions) -> (replacement ++ replacements, patterns, expressions)
| otherwise
= ([], p:pats, e:exprs)
matchArguments pats [] = ([], pats, [])
matchArguments [] exprs = ([], [], exprs)
staticPatternMatch :: Pattern -> Expr -> Maybe [(GHC.Name, Expr)]
staticPatternMatch (VarPat n) e
| Just name <- semanticsName $ n ^. simpleName
= Just [(name, e)]
staticPatternMatch (AppPat n (AnnList args)) e
| (Var n', exprs) <- splitApps e
, length args == length exprs
&& semanticsName (n ^. simpleName) == semanticsName (n' ^. simpleName)
, Just subs <- sequence $ zipWith staticPatternMatch args exprs
= Just $ concat subs
staticPatternMatch (TuplePat (AnnList pats)) (Tuple (AnnList args))
| length pats == length args
, Just subs <- sequence $ zipWith staticPatternMatch pats args
= Just $ concat subs
staticPatternMatch _ _ = Nothing
replaceMatch :: Match -> Alt
replaceMatch (Match lhs rhs locals) = mkAlt (toPattern lhs) (toAltRhs rhs) (locals ^? annJust)
where toPattern (MatchLhs _ (AnnList pats)) = mkTuplePat pats
toPattern (InfixLhs lhs _ rhs (AnnList more)) = mkTuplePat (lhs:rhs:more)
toAltRhs (UnguardedRhs expr) = mkCaseRhs expr
toAltRhs (GuardedRhss (AnnList rhss)) = mkGuardedCaseRhss (map toAltGuardedRhs rhss)
toAltGuardedRhs (GuardedRhs (AnnList guards) expr) = mkGuardedCaseRhs guards expr
wrapLocals :: MaybeLocalBinds -> Expr -> Expr
wrapLocals bnds = case bnds ^? annJust & localBinds & annList of
[] -> id
localBinds -> mkLet localBinds
compositePat :: Pattern -> Bool
compositePat (AppPat {}) = True
compositePat (InfixAppPat {}) = True
compositePat (TypeSigPat {}) = True
compositePat (ViewPat {}) = True
compositePat _ = False
parenIfNeeded :: Expr -> Expr
parenIfNeeded e = if compositeExprs e then mkParen e else e
compositeExprs :: Expr -> Bool
compositeExprs (App {}) = True
compositeExprs (InfixApp {}) = True
compositeExprs (Lambda {}) = True
compositeExprs (Let {}) = True
compositeExprs (If {}) = True
compositeExprs (Case {}) = True
compositeExprs (Do {}) = True
compositeExprs _ = False
createLambda :: [Pattern] -> Expr -> Expr
createLambda [] = id
createLambda pats = mkLambda (map (\p -> if compositePat p then mkParenPat p else p) pats)