{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MonoLocalBinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TupleSections #-}


module Language.Haskell.Tools.Refactor.Builtin.OrganizeExtensions
  ( module Language.Haskell.Tools.Refactor.Builtin.OrganizeExtensions
  , module Language.Haskell.Tools.Refactor.Builtin.ExtensionOrganizer.ExtMonad
  ) where

import Language.Haskell.Tools.Refactor.Builtin.ExtensionOrganizer.ExtMonad
import Language.Haskell.Tools.Refactor.Builtin.ExtensionOrganizer.TraverseAST
import Language.Haskell.Tools.Refactor.Builtin.ExtensionOrganizer.SupportedExtensions (isSupported)

import Language.Haskell.Tools.Refactor hiding (LambdaCase)
import Language.Haskell.Tools.Refactor.Utils.Extensions

import GHC (Ghc(..))

import Control.Reference
import Data.Char (isAlphaNum)
import Data.Function (on)
import Data.Maybe (mapMaybe)
import Data.List
import qualified Data.Map.Strict as SMap (empty, toList)

-- NOTE: When working on the entire AST, we should build a monad,
--       that will will avoid unnecessary checks.
--       For example if it already found a record wildcard, it won't check again

--       Pretty easy now. Chcek wheter it is already in the ExtMap.

highlightExtensionsQuery :: QueryChoice
highlightExtensionsQuery = GlobalQuery "HighlightExtensions" extQuery
  where extQuery :: ModuleDom -> [ModuleDom] -> QueryMonad QueryValue
        extQuery (_,m) _ = lift . fmap MarkerQuery . extensionMarkers $ m

organizeExtensionsRefactoring :: RefactoringChoice
organizeExtensionsRefactoring = ModuleRefactoring "OrganizeExtensions" (localRefactoring organizeExtensions)

projectOrganizeExtensionsRefactoring :: RefactoringChoice
projectOrganizeExtensionsRefactoring = ProjectRefactoring "ProjectOrganizeExtensions" projectOrganizeExtensions

projectOrganizeExtensions :: ProjectRefactoring
projectOrganizeExtensions =
  mapM (\(k, m) -> ContentChanged . (k,) <$> localRefactoringRes id m (organizeExtensions m))

tryOut :: String -> String -> IO ()
tryOut = tryRefactor (localRefactoring . const organizeExtensions)

organizeExtensions :: LocalRefactoring
organizeExtensions moduleAST = do
  exts <- liftGhc $ reduceExtensions moduleAST
  let langExts = map (mkLanguagePragma . pure . serializeExt . show) exts
      ghcOpts  = moduleAST ^? filePragmas & annList & opStr & stringNodeStr
      ghcOpts' = map (mkOptionsGHC . unwords . filter (isPrefixOf "-") . words) ghcOpts

      offExts  = map (mkLanguagePragma . pure)
               . sort
               . map (("No" ++) . serializeExt . show)
               . collectTurnedOffExtensions
               $ moduleAST

      newPragmas = mkFilePragmas $ offExts ++ langExts ++ ghcOpts'

  (filePragmas != newPragmas)
    -- remove empty {-# LANGUAGE #-} pragmas
    >=> filePragmas !~ filterListSt (\case LanguagePragma (AnnList []) -> False; _ -> True)
    $ moduleAST

-- | Reduces default extension list (keeps unsupported extensions)
reduceExtensions :: UnnamedModule -> Ghc [Extension]
reduceExtensions moduleAST = do
  let defaults = map replaceDeprecated . collectDefaultExtensions $ moduleAST
      expanded = expandExtensions defaults
      (xs, ys) = partition isSupported expanded

  xs' <- flip execStateT SMap.empty . flip runReaderT xs . traverseModule $ moduleAST
  let filteredExts = nub . mergeImplied $ (determineExtensions xs' ++ ys)
  if any (`elem` filteredExts) [Cpp, TemplateHaskell, TemplateHaskellQuotes, QuasiQuotes]
    -- We can't say anything about generated code
    then return . mergeImplied $ defaults
    -- Merging is needed because there might be unsopported extensions
    -- that are implied by supported extensions (TypeFamilies -> MonoLocalBinds)
    else return . sortBy (compare `on` show) $ filteredExts

-- | Collect the required extensions in a module and returns a markers associated with them
extensionMarkers :: UnnamedModule -> Ghc [Marker]
extensionMarkers = fmap (concatMap toMarkers . SMap.toList) . collectExtensions
  where toMarkers (rel, occs) = map (toMarker rel) occs
        toMarker  rel occ     = Marker (unOcc occ) Info (showWithLevel rel occ)
        showWithLevel rel occ = (head . words . show $ occ) ++ ": " ++ prettyPrintFormula rel

-- | Collects extensions induced by the source code (with location info)
collectExtensions :: UnnamedModule -> Ghc ExtMap
collectExtensions = collectExtensionsWith traverseModule


-- | Collects the required extensions from a module using the given traversal method
collectExtensionsWith :: CheckNode UnnamedModule -> UnnamedModule -> Ghc ExtMap
collectExtensionsWith trvModule moduleAST = do
  let expanded = expandExtensions . collectDefaultExtensions $ moduleAST
  flip execStateT SMap.empty . flip runReaderT expanded . trvModule $ moduleAST


-- | Expands every extension in a list, while not producing any duplicates.
expandExtensions :: [Extension] -> [Extension]
expandExtensions = nub . concatMap expandExtension

-- | Collects extensions enabled by default
collectDefaultExtensions :: UnnamedModule -> [Extension]
collectDefaultExtensions = mapMaybe toExt . getExtensions

-- | Collects extensions enabled by default
collectTurnedOffExtensions :: UnnamedModule -> [Extension]
collectTurnedOffExtensions = mapMaybe (toExt . drop 2)
                           . filter (isPrefixOf "No")
                           . getExtensions

-- | Collects the string representation of the extensions in the module
getExtensions :: UnnamedModule -> [String]
getExtensions = flip (^?) (filePragmas & annList & lpPragmas & annList & langExt)

toExt :: String -> Maybe Extension
toExt str = case map fst . reads . canonExt . takeWhile isAlphaNum $ str of
              e:_ -> Just e
              []  -> fail $ "Extension '" ++ takeWhile isAlphaNum str ++ "' is not known."