Source code for torchsom.utils.grid

from typing import Tuple

import torch


[docs] def create_mesh_grid( x: int, y: int, device: str, ) -> Tuple[torch.Tensor, torch.Tensor]: """Create a mesh grid for neighborhood calculations. The function returns two 2D tensors representing the x-coordinates and y-coordinates of a grid of shape (x, y). This is useful for computing distance-based neighborhood functions in Self-Organizing Maps (SOM). Args: x (int): Number of rows (height of the grid). y (int): Number of columns (width of the grid). device (str): The device on which tensors should be allocated ('cpu' or 'cuda'). Returns: Tuple[torch.Tensor, torch.Tensor]: Two tensors (xx, yy) of shape (x, y), representing the x and y coordinates of the mesh grid. """ x_tensor, y_tensor = torch.arange(x, device=device), torch.arange( y, device=device ) # Shape: (x) and (y) x_meshgrid, y_meshgrid = torch.meshgrid( x_tensor, y_tensor, indexing="ij" ) # Create 2D meshgrid of shapes (x, y): xx contains x-coordinates, yy contains y-coordinates return ( x_meshgrid.float(), y_meshgrid.float(), ) # convert a torch(int) into a torch(float)
[docs] def adjust_meshgrid_topology( xx: torch.Tensor, yy: torch.Tensor, topology: str, ) -> Tuple[torch.Tensor, torch.Tensor]: """Adjust coordinates based on topology. Args: xx (torch.Tensor): Mesh grid of x coordinates yy (torch.Tensor): Mesh grid of y coordinates topology (str): SOM configuration, usually rectangular or hexagonal Returns: Tuple[torch.Tensor, torch.Tensor]: Adjusted x and y mesh grids for a hexagonal topology. """ if topology == "hexagonal": # Create new tensors to avoid modifying in-place adjusted_xx = xx.clone() adjusted_yy = yy.clone() adjusted_xx[::2] -= 0.5 # Adjust x-coordinates for even-indexed rows adjusted_yy *= (3.0 / 2.0) / torch.sqrt( torch.tensor(3.0) ) # Adjust all y-coordinates return adjusted_xx, adjusted_yy # Return the modified copies return xx, yy # If not hexagonal, return the original tensors
[docs] def convert_to_axial_coords( row: int, col: int, ) -> tuple[float, float]: """Convert grid coordinates to axial coordinates for hexagonal grid. Uses even-r layout where even rows are shifted left by 0.5. This matches the layout used in adjust_meshgrid_topology. Args: row (int): Grid row coordinate col (int): Grid column coordinate Returns: tuple[float, float]: Axial coordinates (q, r) """ if row % 2 == 0: q = col - 0.5 - (row // 2) else: q = col - (row // 2) r = row return q, r
[docs] def offset_to_axial_coords( row: int, col: int, ) -> tuple[float, float]: """Convert offset coordinates to axial coordinates for hexagonal grid. Alternative implementation that directly matches the mesh grid adjustment. Args: row (int): Grid row coordinate col (int): Grid column coordinate Returns: tuple[float, float]: Axial coordinates (q, r) """ # Direct conversion matching adjust_meshgrid_topology q = col - (row - (row & 1)) / 2 r = row return q, r
[docs] def axial_distance( q1: float, r1: float, q2: float, r2: float, ) -> int: """Calculate the distance between two hexes in axial coordinates. Args: q1 (float): column of first hex r1 (float): row of first hex q2 (float): column of second hex r2 (float): row of second hex Returns: int: Distance in hex steps """ # Convert axial to cube coordinates x1, y1, z1 = q1, r1, -q1 - r1 x2, y2, z2 = q2, r2, -q2 - r2 # Manhattan distance divided by 2 return int((abs(x1 - x2) + abs(y1 - y2) + abs(z1 - z2)) / 2)