{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
module AlBhed.Tree
( -- * Types
RedBlackTree
, TreeZipper
-- * Constructors
, empty
-- * Manipulations
, insert
, remove
-- * Search
, elem
, search
-- * Zipper
, fromZipper
, left
, right
, toZipper
, up
) where
import Prelude hiding (elem)
-- | Colour of a red-black tree node. A 'Leaf' has no colour field of its own.
data Color
  = Red
  | Black
  deriving (Eq, Show)
-- | A red-black tree holding values of type @a@, kept in search-tree order
-- by the 'Ord'-based operations of this module.
data RedBlackTree a
  = -- | An interior node: colour, left subtree, stored value, right subtree.
    Node Color (RedBlackTree a) a (RedBlackTree a)
  | -- | An empty subtree.
    Leaf
  deriving (Show, Eq)
-- | The tree containing no elements.
empty :: RedBlackTree a
empty = Leaf
-- | Insert a value, rebalancing with 'balance' on the way back up.
--
-- If an element comparing equal to the new one is already present, the
-- existing node (and its stored element) is kept unchanged. The root of the
-- result is always painted black.
insert :: Ord a => a -> RedBlackTree a -> RedBlackTree a
insert x = blackenRoot . go
  where
    -- New elements always enter as red leaves; 'balance' repairs any
    -- red-red violation this creates.
    go Leaf = Node Red Leaf x Leaf
    go node@(Node color l value r) =
      case compare x value of
        LT -> balance color (go l) value r
        EQ -> node
        GT -> balance color l value (go r)

    -- Total: an empty tree simply stays empty (cannot actually happen here,
    -- since 'go' always returns a 'Node').
    blackenRoot (Node _ l value r) = Node Black l value r
    blackenRoot Leaf = Leaf
-- | Okasaki-style local rebalancing used by 'insert'.
--
-- A black node with a red child that itself has a red child (in any of the
-- four possible positions) is rebuilt as a red node with two black children.
-- Every other node is rebuilt unchanged.
balance :: Color -> RedBlackTree a -> a -> RedBlackTree a -> RedBlackTree a
balance color lhs pivot rhs =
  case (color, lhs, rhs) of
    (Black, Node Red (Node Red a x b) y c, _) -> rotated a x b y c pivot rhs
    (Black, Node Red a x (Node Red b y c), _) -> rotated a x b y c pivot rhs
    (Black, _, Node Red (Node Red b y c) z d) -> rotated lhs pivot b y c z d
    (Black, _, Node Red b y (Node Red c z d)) -> rotated lhs pivot b y c z d
    _ -> Node color lhs pivot rhs
  where
    -- The common result shape of all four rotations.
    rotated a x b y c z d = Node Red (Node Black a x b) y (Node Black c z d)
-- | Membership test: binary search from the root towards the leaves.
elem :: Ord a => a -> RedBlackTree a -> Bool
elem needle = lookIn
  where
    lookIn Leaf = False
    lookIn (Node _ smaller pivot larger)
      | needle < pivot = lookIn smaller
      | needle == pivot = True
      | otherwise = lookIn larger
-- | Which child a zipper step descended into.
data Direction
  = -- | Went to the left child.
    L
  | -- | Went to the right child.
    R
  deriving (Show)
-- | A zipper into a 'RedBlackTree': a breadcrumb path plus the focused subtree.
--
-- Each breadcrumb records the direction taken and the parent node as it was
-- when we descended. The path is ordered root-first: 'left'/'right' append a
-- frame at the end, and 'up' consumes the last one.
type TreeZipper a = ([(Direction, RedBlackTree a)], RedBlackTree a)
-- | A zipper focused on the whole tree, with no breadcrumbs.
toZipper :: RedBlackTree a -> TreeZipper a
toZipper tree = ([], tree)
-- | Move the focus to the left child, remembering the current node as a
-- breadcrumb. On a 'Leaf' the zipper is returned as is.
--
-- NOTE(review): appending to the end of the path costs O(depth) per step.
-- Consing onto the front (and adapting 'up' and 'identifyCase' accordingly)
-- would make it O(1).
left :: TreeZipper a -> TreeZipper a
left zipper = case zipper of
  (_, Leaf) -> zipper
  (frames, focus@(Node _ child _ _)) -> (frames ++ [(L, focus)], child)
-- | Move the focus to the right child, remembering the current node as a
-- breadcrumb. On a 'Leaf' the zipper is returned as is.
--
-- NOTE(review): same O(depth) append cost as 'left'.
right :: TreeZipper a -> TreeZipper a
right zipper = case zipper of
  (_, Leaf) -> zipper
  (frames, focus@(Node _ _ _ child)) -> (frames ++ [(R, focus)], child)
-- | Move the focus to the parent, rebuilding the parent node around the
-- (possibly modified) focused subtree. At the root this is a no-op.
up :: TreeZipper a -> TreeZipper a
up zipper@([], _) = zipper
up (frames, focus) =
  -- 'frames' is non-empty here, so 'last'/'init' are safe.
  case last frames of
    (L, Node c _ v r) -> (init frames, Node c focus v r)
    (R, Node c l v _) -> (init frames, Node c l v focus)
    -- 'left'/'right' only ever push 'Node's, so this means a hand-built or
    -- corrupted zipper; fail with a message instead of an opaque
    -- irrefutable-pattern error.
    (_, Leaf) -> error "AlBhed.Tree.up: zipper path contains a Leaf frame"
-- | Rebuild the whole tree by moving 'up' until no breadcrumbs remain.
fromZipper :: TreeZipper a -> RedBlackTree a
fromZipper = snd . until (null . fst) up
-- | Find an element and return a zipper focused on its node, or 'Nothing'
-- if no element compares equal to it.
search :: forall a. Ord a => a -> RedBlackTree a -> Maybe (TreeZipper a)
search x root =
  case go (toZipper root) of
    (_, Leaf) -> Nothing
    found -> Just found
  where
    -- A single 'compare' makes the descent total. The previous guards
    -- (<, ==, >) with no fallback crashed when all three were False
    -- (e.g. a NaN key).
    go :: TreeZipper a -> TreeZipper a
    go z@(_, Leaf) = z
    go z@(_, Node _ _ value _) =
      case compare x value of
        LT -> go (left z)
        EQ -> z
        GT -> go (right z)
-- | Remove an element from the tree, preserving the red-black invariants.
--
-- If no element compares equal to @x@, the result is structurally equal to
-- the input (for trees built with this module's operations, whose root is
-- black). The root of the result is always painted black, as in 'insert'.
--
-- Deletion is done bottom-up. Each recursive step reports whether the
-- subtree it rebuilt lost one level of black height, and the caller repairs
-- that deficit with recolourings and rotations (the cases of CLRS's
-- RB-DELETE-FIXUP).
remove :: (Ord a) => a -> RedBlackTree a -> RedBlackTree a
remove x = blacken . fst . deleteFrom
  where
    -- Delete @x@ from a subtree. The flag is 'True' when the rebuilt
    -- subtree's black height is one less than the original's.
    deleteFrom Leaf = (Leaf, False)
    deleteFrom (Node color l value r) =
      case compare x value of
        LT ->
          let (l', shorter) = deleteFrom l
           in if shorter
                then fixLeft color l' value r
                else (Node color l' value r, False)
        GT ->
          let (r', shorter) = deleteFrom r
           in if shorter
                then fixRight color l value r'
                else (Node color l value r', False)
        EQ -> removeNode color l r

    -- Remove the node whose value matched, given its colour and children.
    removeNode ::
      Color -> RedBlackTree b -> RedBlackTree b -> (RedBlackTree b, Bool)
    removeNode color Leaf r = replaceWith color r
    removeNode color l Leaf = replaceWith color l
    removeNode color l (Node rc rl rv rr) =
      -- Two children: take the in-order successor (minimum of the right
      -- subtree) as this node's value and delete it from the right subtree.
      let (successor, (r', shorter)) = deleteMin rc rl rv rr
       in if shorter
            then fixRight color l successor r'
            else (Node color l successor r', False)

    -- Delete the minimum of the node given by its parts. Returns that
    -- minimum together with the rebuilt subtree and its "shorter" flag.
    deleteMin ::
      Color -> RedBlackTree b -> b -> RedBlackTree b -> (b, (RedBlackTree b, Bool))
    deleteMin color Leaf value r = (value, replaceWith color r)
    deleteMin color (Node lc ll lv lr) value r =
      let (smallest, (l', shorter)) = deleteMin lc ll lv lr
       in ( smallest
          , if shorter
              then fixLeft color l' value r
              else (Node color l' value r, False)
          )

    -- Splice out a node of the given colour that has at most one non-empty
    -- child, putting that child in its place.
    replaceWith :: Color -> RedBlackTree b -> (RedBlackTree b, Bool)
    -- Removing a red node never changes black height.
    replaceWith Red child = (child, False)
    -- A red child can absorb the lost black by being repainted.
    replaceWith Black (Node Red cl cv cr) = (Node Black cl cv cr, False)
    replaceWith Black child = (child, True)

    -- Rebuild a node whose left subtree is one black level shorter than its
    -- right subtree. The flag reports whether the whole node got shorter.
    fixLeft ::
      Color -> RedBlackTree b -> b -> RedBlackTree b -> (RedBlackTree b, Bool)
    -- Short side has a red root: painting it black restores the height.
    fixLeft color (Node Red sl sv sr) value r =
      (Node color (Node Black sl sv sr) value r, False)
    -- Red sibling (hence black parent): rotate the sibling above the parent,
    -- then repair the now-red parent, whose new sibling is black.
    fixLeft _ l value (Node Red rl rv rr) =
      let (l', shorter) = fixLeft Red l value rl
       in if shorter
            then fixLeft Black l' rv rr
            else (Node Black l' rv rr, False)
    -- Black sibling, far nephew red: single rotation.
    fixLeft color l value (Node Black rl rv rr@(Node Red _ _ _)) =
      (Node color (Node Black l value rl) rv (blacken rr), False)
    -- Black sibling, near nephew red: double rotation.
    fixLeft color l value (Node Black (Node Red ml mv mr) rv rr) =
      (Node color (Node Black l value ml) mv (Node Black mr rv rr), False)
    -- Black sibling with black children: repaint the sibling red. A red
    -- parent absorbs the deficit; a black one passes it upwards.
    fixLeft Red l value (Node Black rl rv rr) =
      (Node Black l value (Node Red rl rv rr), False)
    fixLeft Black l value (Node Black rl rv rr) =
      (Node Black l value (Node Red rl rv rr), True)
    -- Unreachable for valid trees: the sibling of a short subtree has black
    -- height >= 1, so it cannot be a 'Leaf'.
    fixLeft color l value Leaf = (Node color l value Leaf, False)

    -- Mirror image of 'fixLeft': the right subtree is the short one.
    fixRight ::
      Color -> RedBlackTree b -> b -> RedBlackTree b -> (RedBlackTree b, Bool)
    fixRight color l value (Node Red sl sv sr) =
      (Node color l value (Node Black sl sv sr), False)
    fixRight _ (Node Red ll lv lr) value r =
      let (r', shorter) = fixRight Red lr value r
       in if shorter
            then fixRight Black ll lv r'
            else (Node Black ll lv r', False)
    fixRight color (Node Black ll@(Node Red _ _ _) lv lr) value r =
      (Node color (blacken ll) lv (Node Black lr value r), False)
    fixRight color (Node Black ll lv (Node Red ml mv mr)) value r =
      (Node color (Node Black ll lv ml) mv (Node Black mr value r), False)
    fixRight Red (Node Black ll lv lr) value r =
      (Node Black (Node Red ll lv lr) value r, False)
    fixRight Black (Node Black ll lv lr) value r =
      (Node Black (Node Red ll lv lr) value r, True)
    fixRight color Leaf value r = (Node color Leaf value r, False)

    -- Paint the root of a subtree black (no-op on 'Leaf').
    blacken :: RedBlackTree b -> RedBlackTree b
    blacken (Node _ l value r) = Node Black l value r
    blacken Leaf = Leaf
-- | Zipper-based deletion (work in progress): classify the situation around
-- the focused node with 'identifyCase' and dispatch on the result.
--
-- No case is implemented yet. Each one fails with an error that identifies
-- the case, and the match is exhaustive over 'Case'.
remove' :: (Ord a) => TreeZipper a -> RedBlackTree a
remove' zipper = go (identifyCase zipper)
  where
    go :: Case a -> RedBlackTree a
    go D1 = notImplemented "D1"
    go D2 = notImplemented "D2"
    go D3 = notImplemented "D3"
    go (D4 _ _) = error "To be implemented"
    go D5 = notImplemented "D5"
    go D6 = notImplemented "D6"

    notImplemented :: String -> b
    notImplemented label =
      error ("AlBhed.Tree.remove': deletion case " ++ label ++ " is not implemented")
-- | Classification of the rebalancing situation around a node being deleted,
-- as produced by 'identifyCase'.
--
-- NOTE(review): the labels presumably follow the D1–D6 deletion cases of the
-- Wikipedia red-black tree article; confirm. The author also wondered whether
-- this would be a good candidate for a GADT.
data Case a
  = D1
  | D2
  | D3
  | -- | Carries the parent subtree, then the sibling subtree, of the focus.
    D4 (RedBlackTree a) (RedBlackTree a)
  | D5
  | D6
  deriving Show
-- | Classify the deletion situation around the focused node by the colours of
-- its relatives; a 'Leaf' counts as black. The tests are made in this order:
--
-- * the focus is the root (empty path): 'D1'
-- * the sibling is red: 'D3'
-- * the distant nephew (the sibling's child on the far side) is red: 'D6'
-- * the close nephew (the sibling's child on the near side) is red: 'D5'
-- * the parent is red: 'D4', carrying the parent and sibling subtrees
-- * otherwise (everything black): 'D2'
--
-- Previously only the red-parent/all-black combination was matched and every
-- other input (including a root focus) crashed.
identifyCase :: TreeZipper a -> Case a
identifyCase zipper@(path, _)
  | null path = D1
  | isRed sibling = D3
  | isRed distantNephew = D6
  | isRed closeNephew = D5
  | isRed parent = D4 parent sibling
  | otherwise = D2
  where
    isRed t = colorOf t == Red

    colorOf Leaf = Black
    colorOf (Node c _ _ _) = c

    -- Only forced after the 'null path' guard has failed, so 'last' is safe.
    ourDirection = fst (last path)

    parentZ = up zipper
    siblingZ = case ourDirection of
      L -> right parentZ
      R -> left parentZ

    parent = snd parentZ
    sibling = snd siblingZ
    distantNephew = snd $ case ourDirection of
      L -> right siblingZ
      R -> left siblingZ
    closeNephew = snd $ case ourDirection of
      L -> left siblingZ
      R -> right siblingZ