diff --git a/moves/src/generators/pawn.rs b/moves/src/generators/pawn.rs index 515ab16..385e1e9 100644 --- a/moves/src/generators/pawn.rs +++ b/moves/src/generators/pawn.rs @@ -27,7 +27,6 @@ enum MoveType { DoublePushes, LeftCaptures, RightCaptures, - EnPassant, } struct PromotionIterator { @@ -62,10 +61,6 @@ impl PawnMoveGenerator { let (single_pushes, double_pushes) = Self::pushes(pawns, color, empty); let (left_captures, right_captures) = Self::captures(pawns, color, enemies | en_passant); - let en_passant = en_passant & (left_captures | right_captures); - let left_captures = left_captures & !en_passant; - let right_captures = right_captures & !en_passant; - Self { color, single_pushes, @@ -137,22 +132,6 @@ impl PawnMoveGenerator { Color::White => target.neighbor(Direction::SouthWest), Color::Black => target.neighbor(Direction::NorthEast), }, - MoveType::EnPassant => match self.color { - Color::White => { - if (self.en_passant & self.left_captures).is_populated() { - target.neighbor(Direction::SouthEast) - } else { - target.neighbor(Direction::SouthWest) - } - } - Color::Black => { - if (self.en_passant & self.left_captures).is_populated() { - target.neighbor(Direction::NorthWest) - } else { - target.neighbor(Direction::NorthEast) - } - } - }, } } @@ -165,7 +144,6 @@ impl PawnMoveGenerator { MoveType::DoublePushes => self.double_pushes, MoveType::LeftCaptures => self.left_captures, MoveType::RightCaptures => self.right_captures, - MoveType::EnPassant => self.en_passant, }; self.target_iterator = next_bitboard.occupied_squares_trailing(); @@ -229,12 +207,16 @@ impl std::iter::Iterator for PawnMoveGenerator { MoveType::DoublePushes => Some(GeneratedMove { ply: Move::double_push(origin, target), }), - MoveType::LeftCaptures | MoveType::RightCaptures => Some(GeneratedMove { - ply: Move::capture(origin, target), - }), - MoveType::EnPassant => Some(GeneratedMove { - ply: Move::en_passant_capture(origin, target), - }), + MoveType::LeftCaptures | MoveType::RightCaptures => { + let target_bitboard: BitBoard = target.into(); + Some(GeneratedMove { + ply: if (target_bitboard & self.en_passant).is_populated() { + Move::en_passant_capture(origin, target) + } else { + Move::capture(origin, target) + }, + }) + } } } else if self.next_move_type().is_some() { self.next() @@ -252,8 +234,7 @@ impl MoveType { MoveType::SinglePushes => Some(MoveType::DoublePushes), MoveType::DoublePushes => Some(MoveType::LeftCaptures), MoveType::LeftCaptures => Some(MoveType::RightCaptures), - MoveType::RightCaptures => Some(MoveType::EnPassant), - MoveType::EnPassant => None, + MoveType::RightCaptures => None, } } } @@ -261,8 +242,8 @@ impl MoveType { #[cfg(test)] mod tests { use super::*; - use crate::{assert_move_list, assert_move_list_does_not_contain, ply, Move}; - use chessfriend_board::test_board; + use crate::{assert_move_list, assert_move_list_contains, assert_move_list_does_not_contain, ply, Move}; + use chessfriend_board::{fen::FromFenStr, test_board}; use chessfriend_core::{Color, Square}; use std::collections::HashSet; @@ -370,20 +351,8 @@ mod tests { #[test] fn black_d5_left_captures() { let black_captures_board = test_board!(Black Pawn on D5, White Queen on E4); - let generated_moves: HashSet = - PawnMoveGenerator::new(&black_captures_board, Some(Color::Black)).collect(); - assert_eq!( - generated_moves, - [ - GeneratedMove { - ply: Move::capture(Square::D5, Square::E4) - }, - GeneratedMove { - ply: Move::quiet(Square::D5, Square::D4) - } - ] - .into() - ); + let generated_moves = PawnMoveGenerator::new(&black_captures_board, Some(Color::Black)); + assert_move_list!(generated_moves, [ply!(D5 x E4), ply!(D5 - D4),]); } #[test] @@ -527,7 +496,7 @@ mod tests { let generated_moves = PawnMoveGenerator::new(&board, None); - assert_move_list!(generated_moves, [ply!(E5 - E6), ply!(E5 x F6 e.p.),]); + assert_move_list!(generated_moves, [ply!(E5 - E6), ply!(E5 x F6 e.p.)]); } #[test] @@ -556,4 +525,13 @@ mod tests { assert_move_list_does_not_contain!(generated_moves, [ply!(B4 x A3 e.p.)]); } + + #[test] + fn black_en_passant_check() { + let board = Board::from_fen_str("8/8/8/2k5/2pP4/8/B7/4K3 b - d3 0 3").expect("invalid fen"); + println!("{}", board.display()); + + let generated_moves: HashSet<_> = PawnMoveGenerator::new(&board, None).collect(); + assert_move_list_contains!(generated_moves, [ply!(C4 x D3 e.p.)]); + } }