diff --git a/erynrl/geometry.py b/erynrl/geometry.py index b2b6cf4..97aee52 100644 --- a/erynrl/geometry.py +++ b/erynrl/geometry.py @@ -4,7 +4,7 @@ import math from dataclasses import dataclass -from typing import Iterator, Optional, Tuple +from typing import Iterator, Optional, overload, Tuple @dataclass(frozen=True) @@ -14,6 +14,11 @@ class Point: x: int = 0 y: int = 0 + @property + def numpy_index(self) -> Tuple[int, int]: + '''Convert this Point into a tuple suitable for indexing into a numpy map array''' + return (self.x, self.y) + @property def neighbors(self) -> Iterator['Point']: '''Iterator over the neighboring points of `self` in all eight directions.''' @@ -39,7 +44,10 @@ class Point: return (self.x - 1 <= other.x <= self.x + 1) and (self.y - 1 <= other.y <= self.y + 1) def direction_to_adjacent_point(self, other: 'Point') -> Optional['Vector']: - '''Given a point directly adjacent to `self`''' + ''' + Given a point directly adjacent to `self`, return a Vector indicating in + which direction it is adjacent. + ''' for direction in Direction.all(): if (self + direction) != other: continue @@ -254,14 +262,32 @@ class Rect: return Rect(Point(self.origin.x + left, self.origin.y + top), Size(self.size.width - right - left, self.size.height - top - bottom)) - def __contains__(self, pt: Point) -> bool: - if pt.x < self.min_x or pt.x > self.max_x: - return False + @overload + def __contains__(self, other: Point) -> bool: + ... - if pt.y < self.min_y or pt.y > self.max_y: - return False + @overload + def __contains__(self, other: 'Rect') -> bool: + ... - return True + def __contains__(self, other: 'Point | Rect') -> bool: + if isinstance(other, Point): + return self.__contains_point(other) + + if isinstance(other, Rect): + return self.__contains_rect(other) + + raise TypeError(f'{self.__class__.__name__} cannot contain value of type {other.__class__.__name__}') + + def __contains_point(self, pt: Point) -> bool: + return (pt.x >= self.min_x and pt.x <= self.max_x + and pt.y >= self.min_x and pt.y <= self.max_y) + + def __contains_rect(self, other: 'Rect') -> bool: + return (self.min_x <= other.min_x + and self.max_x >= other.max_x + and self.min_y <= other.min_y + and self.max_y >= other.max_y) def __iter__(self): yield tuple(self.origin)