{-# LANGUAGE MultiWayIf #-}
import Codec.Phaser
import Codec.Phaser.Common
import Codec.Phaser.Core (eof)
import Codec.Phaser.UTF8
import Control.Applicative
import Data.Char
import Data.Functor.Compose
import Data.Foldable
import Data.List
import qualified Data.Map as M
import qualified Data.Sequence as S
import System.Environment

-- | A yes/no test on values of type @a@, wrapped so that several tests can be
-- merged with '<>' / 'foldMap'.
newtype Predicate a = Predicate
  { getPredicate :: a -> Bool
  }

-- | The combination of two predicates accepts a value as soon as either of
-- them accepts it (logical OR).
instance Semigroup (Predicate a) where
  Predicate p <> Predicate q = Predicate $ \v -> p v || q v

-- | The neutral predicate accepts nothing, which is the identity for OR.
instance Monoid (Predicate a) where
  mempty = Predicate $ \_ -> False

-- | Parses the rule section: newline-separated lines of the shape
-- @name: a-b or c-d ...@.
--
-- The result maps each rule name (everything before the @':'@) to a test that
-- holds for a value lying inside at least one of that rule's inclusive
-- @a-b@ ranges.
validity :: Monoid p => Phase p Char o (M.Map String (Integer -> Bool))
validity = mconcat <$> sepBy rule (char '\n')
  where
    -- One rule line: the name, a colon, optional whitespace, then its ranges.
    rule = M.singleton <$> munch (/= ':') <*> (char ':' *> munch isSpace *> anyRange)

    -- All ranges of one rule, OR-ed together into a single test.
    anyRange = getPredicate . foldMap Predicate <$> sepBy range orSeparator

    -- A single @lo-hi@ range; both bounds are inclusive.
    range =
      (\lo hi t -> t >= lo && t <= hi)
        <$> positiveIntegerDecimal
        <*> (char '-' *> positiveIntegerDecimal)

    -- The word "or" between two ranges, with optional surrounding whitespace.
    orSeparator = munch isSpace *> string "or" *> munch isSpace

-- | Parses one ticket: comma-separated numbers (each read with
-- 'positiveIntegerDecimal'), kept in their written order as a sequence.
ticket :: Monoid p => Phase p Char o (S.Seq Integer)
ticket = fmap S.fromList (positiveIntegerDecimal `sepBy` char ',')

-- | The parsed contents of one input file.
data Input = Input
  { ranges :: M.Map String (Integer -> Bool)
    -- ^ Rule name mapped to the test telling whether a value satisfies that rule.
  , myTicket :: S.Seq Integer
    -- ^ The values listed under the @your ticket:@ heading.
  , nearbyTickets :: [S.Seq Integer]
    -- ^ The tickets listed under the @nearby tickets:@ heading, one sequence each.
  }

-- | Parses a whole input file: the rule section, then the @your ticket:@
-- section holding one ticket, then the @nearby tickets:@ section holding
-- newline-separated tickets. After the last nearby ticket either a single
-- newline or the end of input is required.
input :: Monoid p => Phase p Char o Input
input = Input <$> validity <*> ownTicket <*> otherTickets
  where
    -- A section heading with optional whitespace on both sides.
    heading title = munch isSpace *> string title *> munch isSpace

    ownTicket = heading "your ticket:" *> ticket

    otherTickets = heading "nearby tickets:" *> sepBy ticket (char '\n') <* terminator

    terminator = () <$ char '\n' <|> eof

-- | Adds up every value found on the nearby tickets that satisfies none of
-- the rules in 'ranges'.
errorRate :: Input -> Integer
errorRate s = sum [v | tk <- nearbyTickets s, v <- toList tk, fitsNoRule v]
  where
    -- True when not a single rule accepts the value.
    fitsNoRule v = not (any ($ v) (ranges s))

-- | Keeps only those nearby tickets that are non-empty and whose every value
-- satisfies at least one rule; the rest of the input is left untouched.
onlyValid :: Input -> Input
onlyValid s =
  s {nearbyTickets = [tk | tk <- nearbyTickets s, not (S.null tk), all fitsSomeRule tk]}
  where
    -- True when at least one rule accepts the value.
    fitsSomeRule v = any ($ v) (ranges s)

-- | Every way of assigning each rule name to a ticket column (0-based index)
-- that is consistent with the nearby tickets of the given input.
--
-- Each column starts out with all rule names as candidates; every nearby
-- ticket then removes, per column, the names whose rule rejects the value in
-- that column. The remaining candidates are resolved by elimination, with
-- backtracking over the column that has the fewest candidates whenever
-- elimination alone makes no progress.
--
-- The nearby tickets are used as given, so callers that want invalid tickets
-- ignored must drop them first.
columns :: Input -> [M.Map String Int]
columns s = map toAssignment (solve reduced)
  where
    rules = ranges s

    -- Does the rule called @name@ accept @v@? A name without a rule accepts nothing.
    accepts v name = maybe False ($ v) (M.lookup name rules)

    -- One candidate list per column of our own ticket, each holding every rule name.
    unknown = M.keys rules <$ myTicket s

    -- Candidates left after each nearby ticket has vetoed the rules it violates.
    reduced =
      foldl'
        (S.zipWith (\names v -> filter (accepts v) names))
        unknown
        (nearbyTickets s)

    -- Turn a fully settled candidate sequence into a name-to-column map.
    -- 'solve' only returns sequences made of singleton lists.
    toAssignment settled =
      fold (S.mapWithIndex (\i [name] -> M.singleton name i) settled)

    isSettled [_] = True
    isSettled _ = False

    solve :: S.Seq [String] -> [S.Seq [String]]
    solve cands
      -- Two columns settled on the same name: contradiction.
      | known /= nub known = []
      -- Every column is settled: a complete solution.
      | length known == S.length cands = [cands]
      -- Elimination removed something: keep eliminating.
      | pruned /= cands = solve pruned
      -- Stuck: guess each candidate of the narrowest undecided column in turn.
      | otherwise = case undecided of
          -- Not reachable: with nothing undecided every column is settled,
          -- and the guards above have already dealt with that situation.
          [] -> solve pruned
          _ ->
            let (i, names) = minimumBy fewerCandidates undecided
             in concatMap (\name -> solve (S.update i [name] pruned)) names
      where
        -- Names already pinned to a column.
        known = [name | [name] <- toList cands]

        -- Remove the pinned names from every column that is not itself pinned.
        pruned = prune <$> cands
        prune names@[_] = names
        prune names = filter (`notElem` known) names

        -- Columns still open, paired with their index, in column order.
        undecided =
          [ (i, names)
          | (i, names) <- zip [0 ..] (toList pruned)
          , not (isSettled names)
          ]

        -- Order by number of candidates; 'minimumBy' keeps the earliest
        -- column when several are equally narrow.
        fewerCandidates (_, a) (_, b) = compare (length a) (length b)

-- | Looks up, on the given ticket, the value of every field whose name starts
-- with @"departure"@.
--
-- @cols@ maps field names to column indices and @tk@ is the ticket. Fields
-- whose index lies outside the ticket are left out of the result.
departureFields :: M.Map String Int -> S.Seq Integer -> M.Map String Integer
departureFields cols tk = M.mapMaybe (`S.lookup` tk) departureCols
  where
    departureCols = M.filterWithKey (\name _ -> "departure" `isPrefixOf` name) cols

-- | Program entry point.
--
-- Expects exactly one command-line argument, the path of the input file.
-- Prints the sum of invalid nearby-ticket values as "Part 1" and the product
-- of our own ticket's @departure@ fields as "Part 2".
--
-- A parse error reported by the parser is printed as-is. Situations that
-- previously crashed on a failed pattern match or on 'head' (wrong argument
-- count, a successful parse with no results, no consistent column
-- assignment) now fail with an explanatory 'IOError'.
main :: IO ()
main = do
  args <- getArgs
  case args of
    [fn] -> run fn
    _ -> do
      name <- getProgName
      ioError (userError ("usage: " ++ name ++ " <input-file>"))
  where
    run fn = do
      parseResult <- parseFile (utf8_stream >># trackPosition >># input) fn
      case parseResult of
        Left e -> print e
        -- The parser can succeed without producing any result.
        Right [] -> ioError (userError "the parser produced no result")
        -- When several parses exist, the first one is used.
        Right (s : _) -> do
          putStrLn $ "Part 1: " ++ show (errorRate s)
          let s1 = onlyValid s
          case columns s1 of
            [] ->
              ioError (userError "Part 2: no consistent assignment of fields to columns")
            -- When several assignments exist, the first one is used.
            cols : _ ->
              putStrLn $ "Part 2: " ++ show (product $ departureFields cols $ myTicket s1)