/// Reads a heightmap from the file named by the first command-line argument
/// and prints the summed risk level of its minima and the three largest
/// basin sizes together with their product.
///
/// Every character of every line becomes one grid cell. Characters that are
/// not decimal digits are stored as `u8::MAX`, the same value the grid
/// reports for coordinates outside its stored area.
///
/// # Panics
///
/// Panics if no filename is given, or if the file cannot be opened or read.
fn main() {
    use std::io::BufRead;

    let filename = std::env::args().nth(1).expect("Expected filename");
    // `&String` already implements `AsRef<Path>`, so it can be passed as-is.
    let file = std::fs::File::open(&filename)
        .unwrap_or_else(|e| panic!("Failed to open {}: {}", filename, e));

    let mut grid: Grid<u8> = Grid::with_outside(u8::MAX);
    for (y, line) in std::io::BufReader::new(file).lines().enumerate() {
        let line = line.unwrap_or_else(|e| {
            panic!("Failed to read line {} of {}: {}", y + 1, filename, e)
        });
        for (x, c) in line.chars().enumerate() {
            // `to_digit(10)` yields 0..=9, so the narrowing cast is lossless.
            grid[(x, y)] = c.to_digit(10).map_or(u8::MAX, |d| d as u8);
        }
    }

    println!(
        "Risk level: {}",
        grid.minima().map(|p| grid[p] as usize + 1).sum::<usize>()
    );

    let basins = grid.basins().iter().map(|(_, s)| *s).top_n(3);
    println!(
        "Top 3 basin sizes: {:?}; product: {}",
        basins,
        basins.iter().product::<usize>()
    );
}

/// A two-dimensional grid stored as a flat, row-major `Vec`.
///
/// Reading a coordinate that is not stored yields `outside`; writing to such
/// a coordinate through `IndexMut` grows the grid to include it.
#[derive(Debug)]
struct Grid<T> {
    /// Cell values in row-major order: cell `(x, y)` is at `y * width + x`.
    inner: Vec<T>,
    /// Value handed out for reads of coordinates that are not stored.
    outside: T,
    /// Number of cells per row; `0` for a grid created by `with_outside`.
    width: usize,
}

impl<T> Grid<T> {
    /// Creates an empty grid whose out-of-range reads return `outside`.
    fn with_outside(outside: T) -> Self {
        Self {
            width: 0,
            outside,
            inner: Vec::new(),
        }
    }

    /// Number of stored rows.
    fn height(&self) -> usize {
        // A width of zero would divide by zero, so clamp the divisor to one.
        let cells_per_row = std::cmp::max(self.width, 1);
        self.inner.len() / cells_per_row
    }

    /// Number of cells per row.
    fn width(&self) -> usize {
        self.width
    }

    /// Iterator over the coordinates adjacent to `point` that fall inside
    /// the grid's current dimensions.
    fn neighbours(&self, point: (usize, usize)) -> Neighbours {
        let dimensions = (self.width(), self.height());
        Neighbours::new(point, dimensions)
    }

    /// Iterator over every stored coordinate, row by row.
    fn points(&self) -> GridPoints {
        let rows = self.height();
        GridPoints::new(self.width, rows)
    }

    /// Builds a new grid of the same width by applying `f` to every cell
    /// together with its coordinates.
    ///
    /// `f` is first called once for the `outside` value, with the coordinates
    /// `(usize::MAX, usize::MAX)`, and then for each stored cell in row-major
    /// order.
    fn map_with_coords<S, F: FnMut((usize, usize), &T) -> S>(
        &self,
        mut f: F,
    ) -> Grid<S> {
        // The outside value must be mapped before any stored cell so that a
        // stateful `f` observes the same call order as always.
        let outside = f((usize::MAX, usize::MAX), &self.outside);
        let inner: Vec<S> = self
            .points()
            .zip(&self.inner)
            .map(|(point, value)| f(point, value))
            .collect();
        Grid {
            inner,
            outside,
            width: self.width,
        }
    }
}

impl<T: Ord> Grid<T> {
    /// Yields the coordinates of every cell that has no lower in-grid
    /// neighbour and at least one strictly higher in-grid neighbour.
    ///
    /// Neighbours of equal height neither qualify nor disqualify a cell, and
    /// a cell without any in-grid neighbour is never reported.
    fn minima(&self) -> impl Iterator<Item = (usize, usize)> + '_ {
        use std::cmp::Ordering;

        self.points().filter(move |&here| {
            let value = &self[here];
            let mut has_higher_neighbour = false;
            // `all` stops at the first lower neighbour, just like an early
            // `break` out of a hand-written loop would.
            let no_lower_neighbour =
                self.neighbours(here).all(|n| match value.cmp(&self[n]) {
                    Ordering::Greater => false,
                    Ordering::Equal => true,
                    Ordering::Less => {
                        has_higher_neighbour = true;
                        true
                    }
                });
            no_lower_neighbour && has_higher_neighbour
        })
    }
}

impl<T: Ord + From<u8> + PartialEq<T>> Grid<T> {
    fn basins(&self) -> Vec<((usize, usize), usize)> {
        let mut uf = self
            .map_with_coords(|p, n| (p, if n == &T::from(9) { 0 } else { 1 }));
        for (p, h) in uf
            .points()
            .zip(self.inner.iter())
            .filter(|(_, h)| *h < &T::from(9))
        {
            // I originally thought this would also find the minimum of each
            // basin but it doesn't always.
            for (n, j) in uf
                .neighbours(p)
                .map(|n| (n, &self[n]))
                .filter(|(_, n)| *n < &T::from(9))
            {
                match h.cmp(j) {
                    std::cmp::Ordering::Less => uf.flow(n, p),
                    std::cmp::Ordering::Equal => (),
                    std::cmp::Ordering::Greater => uf.flow(p, n),
                }
            }
        }
        uf.points()
            .zip(uf.inner.iter().map(|(_, s)| s))
            .filter(|(_, s)| *s > &0)
            .map(|(p, s)| (p, *s))
            .collect()
    }
}

impl<T: Copy> Grid<((usize, usize), T)> {
    /// Union-find "find": returns the root of the tree containing `p` and
    /// re-parents every cell on the path from `p` directly to that root.
    ///
    /// Each cell's first tuple field is its parent coordinate; a root is a
    /// cell that is its own parent. `p` is expected to be a stored cell.
    fn u_find(&mut self, p: (usize, usize)) -> (usize, usize) {
        // Pass 1: follow parent links up to the root. Parents are `Copy`, so
        // they are read by value and no reference outlives an access to
        // `self`.
        let mut root = p;
        loop {
            let parent = self[root].0;
            if parent == root {
                break;
            }
            root = parent;
        }

        // Pass 2: path compression — point every visited cell at the root.
        let mut current = p;
        while current != root {
            let next = self[current].0;
            self[current].0 = root;
            current = next;
        }
        root
    }
}

impl<T: Copy + From<u8> + std::ops::Add<T, Output = T>>
    Grid<((usize, usize), T)>
{
    /// Union-find "union": merges the tree containing `s` into the tree
    /// containing `t`.
    ///
    /// The accumulated value of `s`'s root is added to `t`'s root, `s`'s
    /// root is reset to `T::from(0)` and re-parented to `t`'s root. Nothing
    /// happens when both cells already share a root.
    fn flow(&mut self, mut s: (usize, usize), mut t: (usize, usize)) {
        s = self.u_find(s);
        t = self.u_find(t);
        if s != t {
            // The two roots are accessed one after the other, so no two
            // mutable references into the grid are ever alive at once.
            let moved = std::mem::replace(&mut self[s].1, T::from(0));
            self[s].0 = t;
            let merged = self[t].1 + moved;
            self[t].1 = merged;
        }
    }
}

impl<T> std::ops::Index<(usize, usize)> for Grid<T> {
    type Output = T;

    /// Returns the cell at `(x, y)`, or the grid's `outside` value when the
    /// coordinate is not stored. Never panics.
    fn index(&self, (x, y): (usize, usize)) -> &Self::Output {
        if x >= self.width {
            return &self.outside;
        }
        // Checked arithmetic: an overflowing offset cannot address a stored
        // cell, so it must resolve to `outside` rather than wrap around onto
        // an unrelated cell (or panic in builds with overflow checks).
        y.checked_mul(self.width)
            .and_then(|row_start| row_start.checked_add(x))
            .and_then(|offset| self.inner.get(offset))
            .unwrap_or(&self.outside)
    }
}

impl<T: Clone> std::ops::IndexMut<(usize, usize)> for Grid<T> {
    /// Returns a mutable reference to the cell at `(x, y)`, growing the grid
    /// first if the coordinate is not stored yet.
    ///
    /// Newly created cells are clones of the `outside` value; existing cells
    /// keep their coordinates and values.
    ///
    /// # Panics
    ///
    /// Panics if the required dimensions overflow `usize`.
    fn index_mut(&mut self, (x, y): (usize, usize)) -> &mut Self::Output {
        let new_width = self.width().max(x + 1);
        let new_height = self.height().max(y + 1);
        if new_width > self.width() {
            // The row stride changes, so every cell moves: rebuild the
            // storage, reading old cells (or `outside`) through `Index`.
            self.inner = (0..new_width * new_height)
                .map(|i| self[(i % new_width, i / new_width)].clone())
                .collect();
            self.width = new_width;
        } else if new_height > self.height() {
            // Only rows are appended: existing cells keep their offsets, so
            // padding the tail avoids re-cloning the whole grid.
            self.inner
                .resize(new_width * new_height, self.outside.clone());
        }
        // Plain index syntax works without `IndexMut` being imported.
        &mut self.inner[y * self.width + x]
    }
}

/// Extra adaptors available on every iterator.
trait IterExtra: Iterator {
    /// Drains the iterator and returns its `size` greatest items.
    ///
    /// Fewer items are returned when the iterator is shorter than `size`.
    /// The result is not sorted.
    fn top_n(&mut self, size: usize) -> Vec<Self::Item>
    where
        Self::Item: PartialOrd,
    {
        // Start with the first `size` items, whatever they are.
        let seed: Vec<Self::Item> = self.take(size).collect();

        // Push each further item through the kept set: at every slot the
        // greater value stays and the lesser one is carried on, so the value
        // that drops off the end is the smallest of them all.
        self.fold(seed, |mut kept, mut carried| {
            for slot in kept.iter_mut() {
                if carried > *slot {
                    std::mem::swap(&mut carried, slot);
                }
            }
            kept
        })
    }
}

impl<I: Iterator> IterExtra for I {}

/// Iterator over every `(x, y)` coordinate of a `width` × `height` rectangle,
/// row by row: all `x` values of row 0, then row 1, and so on.
struct GridPoints {
    /// Number of columns in the rectangle.
    width: usize,
    /// Number of rows in the rectangle.
    height: usize,
    /// Column of the next point to yield.
    x: usize,
    /// Row of the next point to yield.
    y: usize,
}

impl GridPoints {
    /// Creates an iterator positioned at `(0, 0)`.
    fn new(width: usize, height: usize) -> Self {
        Self {
            width,
            height,
            x: 0,
            y: 0,
        }
    }
}

impl Iterator for GridPoints {
    type Item = (usize, usize);

    fn next(&mut self) -> Option<Self::Item> {
        // A rectangle without columns contains no points, however many rows
        // it nominally has. Without this guard `(0, y)` would be yielded for
        // each row, contradicting `size_hint`.
        if self.width == 0 || self.y >= self.height {
            return None;
        }
        let point = (self.x, self.y);
        self.x += 1;
        if self.x == self.width {
            // End of row: continue at the start of the next one.
            self.x = 0;
            self.y += 1;
        }
        Some(point)
    }

    /// Exact number of points still to be yielded.
    fn size_hint(&self) -> (usize, Option<usize>) {
        // Whole rows from `y` onwards, minus the part of row `y` already
        // consumed. `y` never exceeds `height`, so this cannot underflow.
        let n = self.width * (self.height - self.y) - self.x;
        (n, Some(n))
    }
}

/// Iterator over the up, right, down and left neighbours of `centre` (in
/// that order), skipping those that fall outside `dimensions`.
struct Neighbours {
    /// Coordinate whose neighbours are produced.
    centre: (usize, usize),
    /// `(width, height)` bounds that neighbours must stay within.
    dimensions: (usize, usize),
    /// Index of the next direction to try: 0 = up, 1 = right, 2 = down,
    /// 3 = left, 4 = exhausted.
    i: u8,
}

impl Neighbours {
    /// Creates an iterator that has not tried any direction yet.
    fn new(centre: (usize, usize), dimensions: (usize, usize)) -> Self {
        Self {
            i: 0,
            centre,
            dimensions,
        }
    }
}

impl Iterator for Neighbours {
    type Item = (usize, usize);

    fn next(&mut self) -> Option<Self::Item> {
        let (cx, cy) = self.centre;
        let (width, height) = self.dimensions;

        // Try the remaining directions in turn until one lands in bounds.
        while self.i < 4 {
            let direction = self.i;
            self.i += 1;

            let candidate = match direction {
                0 => cy.checked_sub(1).map(|y| (cx, y)),
                1 => {
                    let x = cx + 1;
                    if x < width {
                        Some((x, cy))
                    } else {
                        None
                    }
                }
                2 => {
                    let y = cy + 1;
                    if y < height {
                        Some((cx, y))
                    } else {
                        None
                    }
                }
                _ => cx.checked_sub(1).map(|x| (x, cy)),
            };

            if candidate.is_some() {
                return candidate;
            }
        }
        None
    }

    /// At most one neighbour per direction not yet tried; possibly none.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let untried = usize::from(4 - self.i);
        (0, Some(untried))
    }
}