Source code for torchsom.utils.metrics

import warnings
from typing import Tuple, Union

import torch

from ..utils.grid import axial_distance, convert_to_axial_coords


[docs] def calculate_quantization_error( data: torch.Tensor, weights: torch.Tensor, distance_fn: callable, ) -> float: """Calculate quantization error for a SOM. Args: data (torch.Tensor): Input data tensor [batch_size, num_features] or [num_features] weights (torch.Tensor): SOM weights [row_neurons, col_neurons, num_features] distance_fn (callable): Function to compute distances between data and weights Returns: float: Average quantization error value """ # Ensure batch compatibility device = weights.device data = data.to(device) if data.dim() == 1: data = data.unsqueeze(0) # Reshape for distance calculation data_expanded = data.view(data.shape[0], 1, 1, -1) weights_expanded = weights.unsqueeze(0) # Calculate distances between each data point and all neurons distances = distance_fn(data_expanded, weights_expanded) # Calculate minimum distance for each data point min_distances = torch.min(distances.view(distances.shape[0], -1), dim=1)[0] # Return average quantization error return min_distances.mean().item()
[docs] def calculate_topographic_error( data: torch.Tensor, weights: torch.Tensor, distance_fn: callable, topology: str = "rectangular", # xx: torch.Tensor = None, # yy: torch.Tensor = None, ) -> float: """Calculate topographic error for a SOM. Args: data (torch.Tensor): Input data tensor [batch_size, num_features] or [num_features] weights (torch.Tensor): SOM weights [row_neurons, col_neurons, num_features] distance_fn (callable): Function to compute distances between data and weights topology (str, optional): Grid configuration. Defaults to "rectangular". # xx (torch.Tensor, optional): Meshgrid of x coordinates. Required for hexagonal topology. Defaults to None. # yy (torch.Tensor, optional): Meshgrid of y coordinates. Required for hexagonal topology. Defaults to None. Returns: float: Topographic error ratio """ # Ensure batch compatibility device = weights.device data = data.to(device) if data.dim() == 1: data = data.unsqueeze(0) x_dim, y_dim = weights.shape[0], weights.shape[1] if x_dim * y_dim == 1: warnings.warn("The topographic error is not defined for a 1-by-1 map.") return float("nan") # Reshape for distance calculation data_expanded = data.view(data.shape[0], 1, 1, -1) weights_expanded = weights.unsqueeze(0) # Calculate distances between each data point and all neurons distances = distance_fn(data_expanded, weights_expanded) # ! Modification to test: all the lines below could be vectorized # Get top 2 BMU indices for each sample batch_size = distances.shape[0] _, indices = torch.topk(distances.view(batch_size, -1), k=2, largest=False, dim=1) # Compute topographic error based on topology if topology == "hexagonal": # Implement hexagonal topographic error error_count = 0 for i in range(batch_size): # Convert flattened indices to 2D coordinates bmu1_row = int(torch.div(indices[i, 0], y_dim, rounding_mode="floor")) bmu1_col = int(indices[i, 0] % y_dim) bmu2_row = int(torch.div(indices[i, 1], y_dim, rounding_mode="floor")) bmu2_col = int(indices[i, 1] % y_dim) q1, r1 = convert_to_axial_coords(bmu1_row, bmu1_col) q2, r2 = convert_to_axial_coords(bmu2_row, bmu2_col) # Calculate distance in hex steps hex_distance = axial_distance(q1, r1, q2, r2) # Count as error if not neighbors (distance > 1) if hex_distance > 1: error_count += 1 return error_count / batch_size else: # Implement rectangular topographic error threshold = 1.0 # Consider only direct neighbors (4-connectivity) # Convert flattened indices to 2D row, col coordinates bmu_row = torch.div(indices, y_dim, rounding_mode="floor") bmu_col = indices % y_dim # Calculate distances between best and second-best BMUs dx = bmu_row[:, 1] - bmu_row[:, 0] dy = bmu_col[:, 1] - bmu_col[:, 0] distances = torch.sqrt(dx.float() ** 2 + dy.float() ** 2) # Units are not neighbors if distance > threshold return (distances > threshold).float().mean().item()