Some geometry updates

- Add Point.numpy_index to simplify converting Points to indexes
- Update the doc string of Point.direction_to_adjacent_point
- Add a Rect.__contains__ implementation for another Rect
- Refactor the contains implementations above into helper methods
This commit is contained in:
Eryn Wells 2023-03-05 13:35:25 -08:00
parent c488ef9c2b
commit 85928a938d

View file

@ -4,7 +4,7 @@
import math
from dataclasses import dataclass
from typing import Iterator, Optional, Tuple
from typing import Iterator, Optional, overload, Tuple
@dataclass(frozen=True)
@ -14,6 +14,11 @@ class Point:
x: int = 0
y: int = 0
@property
def numpy_index(self) -> Tuple[int, int]:
'''Convert this Point into a tuple suitable for indexing into a numpy map array'''
return (self.x, self.y)
@property
def neighbors(self) -> Iterator['Point']:
'''Iterator over the neighboring points of `self` in all eight directions.'''
@ -39,7 +44,10 @@ class Point:
return (self.x - 1 <= other.x <= self.x + 1) and (self.y - 1 <= other.y <= self.y + 1)
def direction_to_adjacent_point(self, other: 'Point') -> Optional['Vector']:
'''Given a point directly adjacent to `self`'''
'''
Given a point directly adjacent to `self`, return a Vector indicating in
which direction it is adjacent.
'''
for direction in Direction.all():
if (self + direction) != other:
continue
@ -254,14 +262,32 @@ class Rect:
return Rect(Point(self.origin.x + left, self.origin.y + top),
Size(self.size.width - right - left, self.size.height - top - bottom))
def __contains__(self, pt: Point) -> bool:
if pt.x < self.min_x or pt.x > self.max_x:
return False
@overload
def __contains__(self, other: Point) -> bool:
...
if pt.y < self.min_y or pt.y > self.max_y:
return False
@overload
def __contains__(self, other: 'Rect') -> bool:
...
return True
def __contains__(self, other: 'Point | Rect') -> bool:
if isinstance(other, Point):
return self.__contains_point(other)
if isinstance(other, Rect):
return self.__contains_rect(other)
raise TypeError(f'{self.__class__.__name__} cannot contain value of type {other.__class__.__name__}')
def __contains_point(self, pt: Point) -> bool:
return (pt.x >= self.min_x and pt.x <= self.max_x
and pt.y >= self.min_x and pt.y <= self.max_y)
def __contains_rect(self, other: 'Rect') -> bool:
return (self.min_x <= other.min_x
and self.max_x >= other.max_x
and self.min_y <= other.min_y
and self.max_y >= other.max_y)
def __iter__(self):
yield tuple(self.origin)