use std::{io::Write, iter::zip};

use rand::{rngs::ThreadRng, seq::IteratorRandom};

/// What is known (or true) about a single cell of the one-dimensional minefield.
///
/// The same type describes both the player's partial view (which may contain
/// `Unknown`) and fully generated boards.
#[derive(PartialEq, Clone, Copy, Debug)]
enum State {
    /// A mine-free cell carrying the number of bombs in its adjacent cells.
    Number(i8),
    /// A cell that has not been revealed; written as `-` in the program's input.
    Unknown,
    /// A cell that holds a mine.
    Bomb,
}

/// Reports whether the partial view `minefield` could be a view of `truth`.
///
/// Cells are compared pairwise: an `Unknown` cell in `minefield` agrees with
/// anything, while every other cell must be equal to its counterpart in
/// `truth`. If the slices differ in length only the overlapping prefix is
/// checked, because `zip` stops at the shorter one.
fn is_compatible(minefield: &[State], truth: &[State]) -> bool {
    for (seen, actual) in zip(minefield, truth) {
        // A revealed number or a flagged bomb has to match exactly.
        let agrees = *seen == State::Unknown || seen == actual;
        if !agrees {
            return false;
        }
    }
    true
}

/// Reveals cells of `truth` into `minefield`, walking left from `idx`.
///
/// Every visited cell is copied from `truth`. The walk stops at the first cell
/// that is not `Number(0)`; that cell is still revealed, and the cell to its
/// left — when there is one — is flagged as `Bomb`. If every cell down to
/// index 0 is `Number(0)`, the whole prefix is revealed and nothing is flagged.
///
/// `explore` only calls this with `idx` pointing at a `Number(0)` cell.
///
/// # Panics
///
/// Panics if `idx` is out of bounds for `minefield` or `truth`.
fn explore_l(minefield: &mut [State], truth: &[State], idx: usize) {
    for idx in (0..=idx).rev() {
        minefield[idx] = truth[idx];
        if truth[idx] != State::Number(0) {
            // At index 0 there is no left neighbour; `idx - 1` would underflow.
            if let Some(left) = idx.checked_sub(1) {
                minefield[left] = State::Bomb;
            }
            break;
        }
    }
}
/// Reveals cells of `truth` into `minefield`, walking right from `idx`.
///
/// Every visited cell is copied from `truth`. The walk stops at the first cell
/// that is not `Number(0)`; that cell is still revealed, and the cell to its
/// right — when there is one — is flagged as `Bomb`. If every cell up to the
/// end of `minefield` is `Number(0)`, the whole suffix is revealed and nothing
/// is flagged.
///
/// `explore` only calls this with `idx` pointing at a `Number(0)` cell.
///
/// # Panics
///
/// Panics if `truth` is shorter than the range of `minefield` being visited.
fn explore_r(minefield: &mut [State], truth: &[State], idx: usize) {
    for idx in idx..minefield.len() {
        minefield[idx] = truth[idx];
        if truth[idx] != State::Number(0) {
            // The last cell has no right neighbour; indexing `idx + 1` would be
            // out of bounds there.
            if let Some(right) = minefield.get_mut(idx + 1) {
                *right = State::Bomb;
            }
            break;
        }
    }
}

/// Applies the effect of picking cell `idx` to `minefield`, given that the
/// real board is `truth`.
///
/// * `Number(0)`: the reveal spreads in both directions via `explore_l` and
///   `explore_r`.
/// * `Number(1)`: only the picked cell is revealed.
/// * `Number(2)`: the picked cell is revealed and both neighbours are flagged
///   as bombs.
///
/// # Panics
///
/// Any other value of `truth[idx]` (a bomb, an unknown cell, or a different
/// number) hits `unreachable!()`. Out-of-range indices panic as well,
/// including a `Number(2)` sitting on either edge of the field.
fn explore(minefield: &mut [State], truth: &[State], idx: usize) {
    let revealed = truth[idx];
    match revealed {
        State::Number(0) => {
            explore_l(minefield, truth, idx);
            explore_r(minefield, truth, idx);
        }
        State::Number(1) => minefield[idx] = revealed,
        State::Number(2) => {
            minefield[idx] = revealed;
            // A 2 in a single row means both neighbours are mines.
            for neighbour in [idx - 1, idx + 1] {
                minefield[neighbour] = State::Bomb;
            }
        }
        _ => unreachable!(),
    }
}

/// Recursively scores candidate moves on `minefield` against the sampled
/// boards in `truths` and returns `(best cell index, its score)`.
///
/// The position is treated as terminal, returning index 0 with the maximum
/// score, when either
/// * the number of `Bomb` and `Unknown` cells equals `num_mines` (every
///   unrevealed cell is a mine), or
/// * `depth` has reached `max_depth`.
///
/// Otherwise up to 10 `Unknown` cells are drawn at random as candidates. A
/// candidate's score is the sum, over every truth in which that cell is not a
/// bomb, of the score of the position reached by exploring it there, divided
/// (integer division) by the total number of truths; truths in which the cell
/// is a bomb therefore contribute zero. The first candidate with the highest
/// score wins. If only one candidate exists it is returned directly with the
/// maximum score.
///
/// # Panics
///
/// Panics if `truths` is empty, if any truth differs in length from
/// `minefield`, or if a non-terminal position has no `Unknown` cell left.
fn evaluate(
    rng: &mut ThreadRng,
    minefield: &[State],
    truths: Vec<Vec<State>>,
    num_mines: usize,
    depth: i64,
    max_depth: i64,
) -> (usize, i64) {
    /// Score given to terminal positions and forced moves.
    const MAX_SCORE: i64 = 100000;

    assert!(truths.iter().all(|truth| truth.len() == minefield.len()));
    assert!(!truths.is_empty());

    let unrevealed = minefield
        .iter()
        .filter(|cell| matches!(cell, State::Bomb | State::Unknown))
        .count();
    if unrevealed == num_mines || depth >= max_depth {
        return (0, MAX_SCORE);
    }

    let options = minefield
        .iter()
        .enumerate()
        .filter(|(_, cell)| **cell == State::Unknown)
        .map(|(idx, _)| idx)
        .choose_multiple(rng, 10);

    assert!(!options.is_empty());

    if options.len() == 1 {
        return (options[0], MAX_SCORE);
    }

    let mut best_option = 0;
    let mut best_score = i64::MIN;

    for idx in options {
        let mut score = 0;
        for truth in &truths {
            // Stepping on a mine ends the game: that truth adds nothing.
            if truth[idx] == State::Bomb {
                continue;
            }
            let mut next = minefield.to_vec();
            explore(&mut next, truth, idx);
            // Keep only the sampled boards still consistent with what the
            // move would have revealed.
            let next_truths: Vec<Vec<State>> = truths
                .iter()
                .filter(|candidate| is_compatible(&next, candidate))
                .cloned()
                .collect();
            score += evaluate(rng, &next, next_truths, num_mines, depth + 1, max_depth).1;
        }
        score /= truths.len() as i64;
        if score > best_score {
            best_score = score;
            best_option = idx;
        }
    }

    (best_option, best_score)
}

fn generate(rng: &mut ThreadRng, len: usize, num_mines: usize) -> Vec<State> {
    // choose_multiple_mut doesn't exist and idk a better way
    let mut minefield = vec![State::Number(0); len];
    let idxs: Vec<usize> = minefield
        .iter()
        .enumerate()
        .choose_multiple(rng, num_mines)
        .iter()
        .map(|(idx, _)| *idx)
        .collect();

    for idx in idxs {
        minefield[idx] = State::Bomb;
        if idx > 0
            && let State::Number(bombs) = minefield[idx - 1]
        {
            minefield[idx - 1] = State::Number(bombs + 1);
        }

        if idx < minefield.len() - 1
            && let State::Number(bombs) = minefield[idx + 1]
        {
            minefield[idx + 1] = State::Number(bombs + 1);
        }
    }
    return minefield;
}

/// Picks the index of the cell to reveal next on `minefield`, which is known
/// to contain `num_mines` mines.
///
/// Random boards are generated until 100 of them are compatible with the
/// revealed cells (rejection sampling). The distinct boards among those are
/// then handed to `evaluate`, whose best cell index is returned.
///
/// Note that the sampling loop has no upper bound: if no board compatible
/// with `minefield` exists, this function never returns.
fn solve(minefield: &[State], num_mines: usize) -> usize {
    /// Number of compatible boards to draw before evaluating.
    const SAMPLES: usize = 100;
    /// Search depth limit passed to `evaluate`.
    const MAX_DEPTH: i64 = 10;

    let mut rng = rand::rng();
    let mut truths: Vec<Vec<State>> = Vec::with_capacity(SAMPLES);

    let mut accepted = 0;
    while accepted < SAMPLES {
        let truth = generate(&mut rng, minefield.len(), num_mines);
        if is_compatible(minefield, &truth) {
            accepted += 1;
            // `Vec::dedup` only drops *consecutive* repeats, which on an
            // unsorted list leaves most duplicates in place. Checking on
            // insertion removes all of them; at most 100 entries, so the
            // linear scan is cheap.
            if !truths.contains(&truth) {
                truths.push(truth);
            }
        }
    }

    evaluate(&mut rng, minefield, truths, num_mines, 0, MAX_DEPTH).0
}

/// Interactive loop: reads one board per line and prints a caret under the
/// cell that `solve` recommends.
///
/// Each line consists of whitespace-separated tokens: the mine count, one
/// token that is skipped without being inspected, and then one token per cell
/// (`-` for an unknown cell, otherwise the cell's number).
///
/// The loop ends at end of input. Blank lines are ignored, and lines that
/// cannot be parsed are reported on stderr and skipped.
fn main() {
    loop {
        let mut buf = String::new();
        print!("> ");
        std::io::stdout().flush().unwrap();
        // `read_line` yields 0 bytes only at end of input; stop cleanly
        // instead of panicking on the missing first token.
        if std::io::stdin().read_line(&mut buf).unwrap() == 0 {
            break;
        }

        let mut split = buf.split_whitespace();
        let Some(count_token) = split.next() else {
            continue;
        };
        let Ok(num_mines) = count_token.parse::<usize>() else {
            eprintln!("invalid mine count: {count_token:?}");
            continue;
        };
        // The token between the mine count and the cells is not used.
        split.next();
        let parsed: Result<Vec<State>, _> = split
            .map(|c| match c {
                "-" => Ok(State::Unknown),
                n => n.parse::<i8>().map(State::Number),
            })
            .collect();
        let minefield = match parsed {
            Ok(minefield) => minefield,
            Err(err) => {
                eprintln!("invalid cell: {err}");
                continue;
            }
        };

        // Six spaces of left padding, then the caret right-aligned in a field
        // two columns wide per cell index.
        println!("      {:>1$}", '^', solve(&minefield, num_mines) * 2 + 1);
    }
}

#[cfg(test)]
mod test {
    use super::State::{Bomb, Number, Unknown};
    use super::*;

    /// With no mines anywhere, exploring from the middle reveals every cell.
    #[test]
    fn explore_explores_all_cells() {
        let truth = vec![Number(0); 5];
        let mut minefield = vec![Unknown; 5];

        explore(&mut minefield, &truth, 2);

        assert_eq!(minefield, vec![Number(0); 5]);
    }

    /// Exploring from a zero stops at the first number and flags the bomb
    /// behind it; the cell beyond the bomb stays hidden.
    #[test]
    fn explore_reveals_until_first_number() {
        let truth = [Number(0), Bomb, Number(1), Number(0), Number(0)];
        let mut minefield = [Unknown; 5];

        explore(&mut minefield, &truth, 3);

        assert_eq!(
            minefield,
            [Unknown, Bomb, Number(1), Number(0), Number(0)]
        );
    }

    /// Exploring a 1 reveals that single cell and nothing else.
    #[test]
    fn explore_has_unknown() {
        let truth = [Number(0), Bomb, Number(1), Number(0), Number(0)];
        let mut minefield = [Unknown; 5];

        explore(&mut minefield, &truth, 2);

        assert_eq!(
            minefield,
            [Unknown, Unknown, Number(1), Unknown, Unknown]
        );
    }

    /// One mine next to a revealed 1: the far cell is the one to pick.
    #[test]
    fn test_case_2() {
        let minefield = [Number(0), Number(1), Unknown, Unknown];

        assert_eq!(solve(&minefield, 1), 3);
    }

    /// Three mines with a revealed 1 and 2: cell 2 is the expected pick.
    #[test]
    fn test_case_6() {
        let minefield = [
            Unknown,
            Unknown,
            Unknown,
            Number(1),
            Unknown,
            Number(2),
            Unknown,
        ];

        assert_eq!(solve(&minefield, 3), 2);
    }

    /// Four mines on a ten-cell board: cell 0 is the expected pick.
    #[test]
    fn test_case_7() {
        let minefield = [
            Unknown,
            Unknown,
            Number(2),
            Unknown,
            Unknown,
            Number(1),
            Number(1),
            Unknown,
            Number(1),
            Number(0),
        ];

        assert_eq!(solve(&minefield, 4), 0);
    }
}
