Source code for torchsom.utils.grid

"""Utility functions for grid operations."""

import math

import torch

# NOTE: Coordinate conversion functions moved to hexagonal_coordinates.py to eliminate duplication and provide single source of truth.


def create_mesh_grid(
    x: int,
    y: int,
    device: str,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build the pair of coordinate grids used for neighborhood computations.

    Both returned tensors have shape ``(x, y)`` and dtype ``torch.float32``.
    Because the grid is built with ``indexing="ij"``, the first tensor holds
    the index along the first axis (``xx[i, j] == i``) and the second tensor
    holds the index along the second axis (``yy[i, j] == j``).

    Args:
        x (int): Number of rows (size of the first axis of the grid).
        y (int): Number of columns (size of the second axis of the grid).
        device (str): Device on which the tensors are allocated
            (e.g. ``"cpu"`` or ``"cuda"``).

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``(xx, yy)``, two float tensors of
        shape ``(x, y)`` containing the first-axis and second-axis
        coordinates of every grid cell.
    """
    # 1D integer index vectors of shapes (x,) and (y,).
    row_indices = torch.arange(x, device=device)
    col_indices = torch.arange(y, device=device)

    # "ij" indexing keeps the (x, y) layout: axis 0 follows `x`, axis 1 follows `y`.
    row_grid, col_grid = torch.meshgrid(row_indices, col_indices, indexing="ij")

    # Converting the integer grids to float materialises new float32 tensors.
    return row_grid.float(), col_grid.float()
def adjust_meshgrid_topology(
    xx: torch.Tensor,
    yy: torch.Tensor,
    topology: str,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Adjust mesh grid coordinates for the requested SOM topology.

    For ``topology == "hexagonal"`` new tensors are returned and the inputs
    are left untouched. For any other topology value the input tensors are
    returned as-is (same objects, no copy).

    Args:
        xx (torch.Tensor): Mesh grid of x coordinates.
        yy (torch.Tensor): Mesh grid of y coordinates.
        topology (str): SOM configuration, usually ``"rectangular"`` or
            ``"hexagonal"``.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: The adjusted ``(xx, yy)`` grids
        when the topology is hexagonal, otherwise the original ``(xx, yy)``.
    """
    if topology != "hexagonal":
        return xx, yy

    # Work on copies so the caller's grids are never modified in place.
    shifted_xx = xx.clone()
    scaled_yy = yy.clone()

    # Offset layout (the original comments call it "even-r", consistent with
    # the visualization code):
    #   1. Even slices along axis 0 (0, 2, 4, ...): no shift.
    #   2. Odd slices along axis 0 (1, 3, 5, ...): shifted by +0.5.
    # NOTE(review): the +0.5 is added to `xx` itself. If `xx` comes from a
    # meshgrid built with indexing="ij", `xx` holds the axis-0 index, so the
    # offset lands on the same coordinate that selects the slice — confirm
    # this is the intended hexagonal geometry.
    shifted_xx[1::2] += 0.5

    # sqrt(3) / 2 is the spacing between adjacent rows of unit-spaced hexagons.
    scaled_yy *= math.sqrt(3) / 2

    return shifted_xx, scaled_yy
# def convert_to_axial_coords( # row: int, # col: int, # ) -> tuple[float, float]: # """Convert grid coordinates to axial coordinates for hexagonal grid. # Uses even-r layout where even rows are shifted left by 0.5. # This matches the layout used in adjust_meshgrid_topology. # Args: # row (int): Grid row coordinate # col (int): Grid column coordinate # Returns: # tuple[float, float]: Axial coordinates (q, r) # """ # q = col - 0.5 - row // 2 if row % 2 == 0 else col - row // 2 # r = row # return q, r # def offset_to_axial_coords( # row: int, # col: int, # ) -> tuple[float, float]: # pragma: no cover # """Convert offset coordinates to axial coordinates for hexagonal grid. # Alternative implementation that directly matches the mesh grid adjustment. # Args: # row (int): Grid row coordinate # col (int): Grid column coordinate # Returns: # tuple[float, float]: Axial coordinates (q, r) # """ # # Direct conversion matching adjust_meshgrid_topology # q = col - (row - (row & 1)) / 2 # r = row # return q, r # def axial_distance( # q1: float, # r1: float, # q2: float, # r2: float, # ) -> int: # """Calculate the distance between two hexes in axial coordinates. # Args: # q1 (float): column of first hex # r1 (float): row of first hex # q2 (float): column of second hex # r2 (float): row of second hex # Returns: # int: Distance in hex steps # """ # # Convert axial to cube coordinates # x1, y1, z1 = q1, r1, -q1 - r1 # x2, y2, z2 = q2, r2, -q2 - r2 # # Manhattan distance divided by 2 # return int((abs(x1 - x2) + abs(y1 - y2) + abs(z1 - z2)) / 2)