diff --git a/board/src/lib.rs b/board/src/lib.rs index 61b699d..44d142c 100644 --- a/board/src/lib.rs +++ b/board/src/lib.rs @@ -11,4 +11,4 @@ mod tests; pub use moves::Move; pub use position::Position; -pub use square::Square; +pub use square::{File, Rank, Square}; diff --git a/board/src/square.rs b/board/src/square.rs index bba04e3..8dfb5b7 100644 --- a/board/src/square.rs +++ b/board/src/square.rs @@ -14,110 +14,198 @@ pub enum Direction { NorthEast, } +#[derive(Debug)] +pub struct ParseFileError; + #[derive(Debug)] pub struct ParseSquareError; #[derive(Debug)] pub struct SquareOutOfBoundsError; -#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] -pub struct Square { - rank: u8, - file: u8, - index: u8, +macro_rules! coordinate_enum { + ($name: ident, $($variant:ident),*) => { + #[repr(u8)] + #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] + pub enum $name { + $($variant), * + } + + impl $name { + pub const NUM: usize = [$(Self::$variant), *].len(); + pub const ALL: [Self; Self::NUM] = [$(Self::$variant), *]; + + #[inline] + pub(crate) fn from_index(index: usize) -> Self { + assert!( + index < Self::NUM, + "Index {} out of bounds for {}.", + index, + stringify!($name) + ); + Self::try_index(index).unwrap() + } + + pub fn try_index(index: usize) -> Option { + $( + #[allow(non_upper_case_globals)] + const $variant: usize = $name::$variant as usize; + )* + + match index { + $($variant => Some($name::$variant),)* + _ => None, + } + } + } + } +} + +#[rustfmt::skip] +coordinate_enum!(Rank, + One, Two, Three, Four, Five, Six, Seven, Eight +); + +#[rustfmt::skip] +coordinate_enum!(File, + A, B, C, D, E, F, G, H +); + +#[rustfmt::skip] +coordinate_enum!(Square, + A1, B1, C1, D1, E1, F1, G1, H1, + A2, B2, C2, D2, E2, F2, G2, H2, + A3, B3, C3, D3, E3, F3, G3, H3, + A4, B4, C4, D4, E4, F4, G4, H4, + A5, B5, C5, D5, E5, F5, G5, H5, + A6, B6, C6, D6, E6, F6, G6, H6, + A7, B7, C7, D7, E7, F7, G7, H7, + A8, B8, C8, D8, E8, F8, G8, H8 +); + +impl Into for File { + fn into(self) -> char { + ('a' as u8 + self as u8) as char + } +} + +impl TryFrom for File { + type Error = ParseFileError; + + fn try_from(value: char) -> Result { + let lowercase_value = value.to_ascii_lowercase(); + for file in File::ALL.iter() { + if lowercase_value == (*file).into() { + return Ok(*file); + } + } + + Err(ParseFileError) + } +} + +impl fmt::Display for File { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", Into::::into(*self).to_uppercase()) + } +} + +impl Into for Rank { + fn into(self) -> char { + ('1' as u8 + self as u8) as char + } +} + +impl TryFrom for Rank { + type Error = ParseFileError; + + fn try_from(value: char) -> Result { + let lowercase_value = value.to_ascii_lowercase(); + for rank in Self::ALL.iter().cloned() { + if lowercase_value == rank.into() { + return Ok(rank); + } + } + + Err(ParseFileError) + } +} + +impl fmt::Display for Rank { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", Into::::into(*self)) + } } impl Square { - pub fn from_rank_file(rank: u8, file: u8) -> Result { - if rank >= 8 || file >= 8 { - return Err(SquareOutOfBoundsError); - } - - Ok(Square { - rank, - file, - index: rank * 8 + file, - }) + #[inline] + pub fn from_file_rank(file: File, rank: Rank) -> Square { + Self::from_index((rank as usize) << 3 | file as usize) } + #[inline] + pub fn file(self) -> File { + File::from_index(self as usize & 0b000111) + } + + #[inline] + pub fn rank(self) -> Rank { + Rank::from_index(self as usize >> 3) + } +} + +impl Square { pub fn from_algebraic_str(s: &str) -> Result { s.parse() } - pub fn rank_file(&self) -> (u8, u8) { - (self.rank, self.file) - } - - pub fn neighbor(&self, direction: Direction) -> Option { + pub fn neighbor(self, direction: Direction) -> Option { match direction { - Direction::North => Square::from_index(self.index + 8), + Direction::North => Square::try_index(self as usize + 8), Direction::NorthWest => { - if self.rank < 7 { - Square::from_index(self.index + 7) + if self.rank() != Rank::Eight { + Square::try_index(self as usize + 7) } else { - Err(SquareOutOfBoundsError) + None } } Direction::West => { - if self.file > 0 { - Square::from_index(self.index - 1) + if self.file() != File::A { + Square::try_index(self as usize - 1) } else { - Err(SquareOutOfBoundsError) + None } } Direction::SouthWest => { - if self.rank > 0 { - Square::from_index(self.index - 9) + if self.rank() != Rank::One { + Square::try_index(self as usize - 9) } else { - Err(SquareOutOfBoundsError) + None } } Direction::South => { - if self.rank > 0 { - Square::from_index(self.index - 8) + if self.rank() != Rank::One { + Square::try_index(self as usize - 8) } else { - Err(SquareOutOfBoundsError) + None } } Direction::SouthEast => { - if self.rank > 0 { - Square::from_index(self.index - 7) + if self.rank() != Rank::One { + Square::try_index(self as usize - 7) } else { - Err(SquareOutOfBoundsError) + None } } Direction::East => { - if self.file < 7 { - Square::from_index(self.index + 1) + if self.file() != File::H { + Square::try_index(self as usize + 1) } else { - Err(SquareOutOfBoundsError) + None } } - Direction::NorthEast => Square::from_index(self.index + 9), + Direction::NorthEast => Square::try_index(self as usize + 9), } - .ok() - } -} - -impl Square { - pub(crate) fn from_index(index: u8) -> Result { - if index >= 64 { - return Err(SquareOutOfBoundsError); - } - - Ok(Square::from_index_unsafe(index)) - } - - pub(crate) fn from_index_unsafe(index: u8) -> Square { - Square { - rank: index / 8, - file: index % 8, - index: index, - } - } - - pub(crate) fn index(&self) -> u8 { - self.index } } @@ -125,47 +213,34 @@ impl FromStr for Square { type Err = ParseSquareError; fn from_str(s: &str) -> Result { - if !s.is_ascii() || s.len() != 2 { - return Err(ParseSquareError); - } - let mut chars = s.chars(); - let file_char = chars.next().unwrap().to_ascii_lowercase(); - if !file_char.is_ascii_lowercase() { - return Err(ParseSquareError); - } - - let file = (file_char as u8) - ('a' as u8); - if file >= 8 { - return Err(ParseSquareError); - } - - let converted_rank_digit = chars + let file: File = chars .next() - .unwrap() - .to_digit(10) - .and_then(|x| if x >= 1 && x <= 8 { Some(x) } else { None }) + .and_then(|c| c.try_into().ok()) .ok_or(ParseSquareError)?; - let rank = u8::try_from(converted_rank_digit).map_err(|_| ParseSquareError)? - 1; - Ok(Square { - rank, - file, - index: rank * 8 + file, - }) - } -} + let rank: Rank = chars + .next() + .and_then(|c| c.try_into().ok()) + .ok_or(ParseSquareError)?; -impl Into for Square { - fn into(self) -> BitBoard { - BitBoard::new(1 << self.index) + if !chars.next().is_none() { + return Err(ParseSquareError); + } + + Ok(Square::from_file_rank(file, rank)) } } impl fmt::Display for Square { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}{}", ('a' as u8 + self.file) as char, self.rank + 1) + write!( + f, + "{}{}", + ('a' as u8 + self.file() as u8) as char, + self.rank() as usize + 1 + ) } } @@ -175,17 +250,17 @@ mod tests { #[test] fn good_algebraic_input() { - let sq1 = Square::from_algebraic_str("a4").expect("Failed to parse 'a4' square"); - assert_eq!(sq1.file, 0); - assert_eq!(sq1.rank, 3); + let sq = Square::from_algebraic_str("a4").expect("Failed to parse 'a4' square"); + assert_eq!(sq.file(), File::A); + assert_eq!(sq.rank(), Rank::Four); - let sq2 = Square::from_algebraic_str("B8").expect("Failed to parse 'B8' square"); - assert_eq!(sq2.file, 1); - assert_eq!(sq2.rank, 7); + let sq = Square::from_algebraic_str("B8").expect("Failed to parse 'B8' square"); + assert_eq!(sq.file(), File::B); + assert_eq!(sq.rank(), Rank::Eight); - let sq3 = Square::from_algebraic_str("e4").expect("Failed to parse 'B8' square"); - assert_eq!(sq3.rank, 3, "Expected rank of e4 to be 3"); - assert_eq!(sq3.file, 4, "Expected file of e4 to be 4"); + let sq = Square::from_algebraic_str("e4").expect("Failed to parse 'B8' square"); + assert_eq!(sq.file(), File::E); + assert_eq!(sq.rank(), Rank::Four); } #[test] @@ -200,81 +275,50 @@ mod tests { #[test] fn from_index() { - let sq1 = Square::from_index(4).expect("Unable to get Square from index"); - assert_eq!(sq1.rank, 0); - assert_eq!(sq1.file, 4); + let sq = Square::try_index(4).expect("Unable to get Square from index"); + assert_eq!(sq.file(), File::E); + assert_eq!(sq.rank(), Rank::One); - let sq1 = Square::from_index(28).expect("Unable to get Square from index"); - assert_eq!(sq1.rank, 3); - assert_eq!(sq1.file, 4); + let sq = Square::try_index(28).expect("Unable to get Square from index"); + assert_eq!(sq.file(), File::E); + assert_eq!(sq.rank(), Rank::Four); } #[test] fn valid_neighbors() { - let sq = Square::from_index_unsafe(28); + let sq = Square::E4; - assert_eq!( - sq.neighbor(Direction::North), - Some(Square::from_index_unsafe(36)) - ); - - assert_eq!( - sq.neighbor(Direction::NorthEast), - Some(Square::from_index_unsafe(37)) - ); - - assert_eq!( - sq.neighbor(Direction::East), - Some(Square::from_index_unsafe(29)) - ); - - assert_eq!( - sq.neighbor(Direction::SouthEast), - Some(Square::from_index_unsafe(21)) - ); - - assert_eq!( - sq.neighbor(Direction::South), - Some(Square::from_index_unsafe(20)) - ); - - assert_eq!( - sq.neighbor(Direction::SouthWest), - Some(Square::from_index_unsafe(19)) - ); - - assert_eq!( - sq.neighbor(Direction::West), - Some(Square::from_index_unsafe(27)) - ); - - assert_eq!( - sq.neighbor(Direction::NorthWest), - Some(Square::from_index_unsafe(35)) - ); + assert_eq!(sq.neighbor(Direction::North), Some(Square::E5)); + assert_eq!(sq.neighbor(Direction::NorthEast), Some(Square::F5)); + assert_eq!(sq.neighbor(Direction::East), Some(Square::F4)); + assert_eq!(sq.neighbor(Direction::SouthEast), Some(Square::F3)); + assert_eq!(sq.neighbor(Direction::South), Some(Square::E3)); + assert_eq!(sq.neighbor(Direction::SouthWest), Some(Square::D3)); + assert_eq!(sq.neighbor(Direction::West), Some(Square::D4)); + assert_eq!(sq.neighbor(Direction::NorthWest), Some(Square::D5)); } #[test] fn invalid_neighbors() { - let square0 = Square::from_index_unsafe(0); - assert!(square0.neighbor(Direction::West).is_none()); - assert!(square0.neighbor(Direction::SouthWest).is_none()); - assert!(square0.neighbor(Direction::South).is_none()); + let sq = Square::A1; + assert!(sq.neighbor(Direction::West).is_none()); + assert!(sq.neighbor(Direction::SouthWest).is_none()); + assert!(sq.neighbor(Direction::South).is_none()); - let square7 = Square::from_index_unsafe(7); - assert!(square7.neighbor(Direction::East).is_none()); - assert!(square7.neighbor(Direction::SouthEast).is_none()); - assert!(square7.neighbor(Direction::South).is_none()); + let sq = Square::H1; + assert!(sq.neighbor(Direction::East).is_none()); + assert!(sq.neighbor(Direction::SouthEast).is_none()); + assert!(sq.neighbor(Direction::South).is_none()); - let square56 = Square::from_index_unsafe(56); - assert!(square56.neighbor(Direction::North).is_none()); - assert!(square56.neighbor(Direction::NorthWest).is_none()); - assert!(square56.neighbor(Direction::West).is_none()); + let sq = Square::A8; + assert!(sq.neighbor(Direction::North).is_none()); + assert!(sq.neighbor(Direction::NorthWest).is_none()); + assert!(sq.neighbor(Direction::West).is_none()); - let square63 = Square::from_index_unsafe(63); - assert!(square63.neighbor(Direction::North).is_none()); - assert!(square63.neighbor(Direction::NorthEast).is_none()); - assert!(square63.neighbor(Direction::East).is_none()); + let sq = Square::H8; + assert!(sq.neighbor(Direction::North).is_none()); + assert!(sq.neighbor(Direction::NorthEast).is_none()); + assert!(sq.neighbor(Direction::East).is_none()); } #[test]