"""Self Organizing Map implementation (``torchsom.core.som``)."""

import heapq
import random
import warnings
from collections import Counter, defaultdict
from typing import Any, Dict, List, Tuple, Union

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

from ..utils.decay import DECAY_FUNCTIONS
from ..utils.distances import DISTANCE_FUNCTIONS
from ..utils.grid import adjust_meshgrid_topology, create_mesh_grid
from ..utils.initialization import initialize_weights
from ..utils.metrics import calculate_quantization_error, calculate_topographic_error
from ..utils.neighborhood import NEIGHBORHOOD_FUNCTIONS
from ..utils.topology import get_all_neighbors_up_to_order
from .base_som import BaseSOM


class SOM(BaseSOM):
    """PyTorch implementation of Self Organizing Maps using batch learning.

    Args:
        BaseSOM: Abstract base class for SOM variants
    """

    def __init__(
        self,
        x: int,
        y: int,
        num_features: int,
        epochs: int = 10,
        batch_size: int = 5,
        sigma: float = 1.0,
        learning_rate: float = 0.5,
        neighborhood_order: int = 1,
        topology: str = "rectangular",
        lr_decay_function: str = "asymptotic_decay",
        sigma_decay_function: str = "asymptotic_decay",
        neighborhood_function: str = "gaussian",
        distance_function: str = "euclidean",
        initialization_mode: str = "random",
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
        random_seed: int = 42,
    ):
        """Initialize the SOM.

        Args:
            x (int): Number of rows
            y (int): Number of cols
            num_features (int): Number of input features
            epochs (int, optional): Number of epochs to train. Defaults to 10.
            batch_size (int, optional): Number of samples considered at each training step. Defaults to 5.
            sigma (float, optional): Width of the neighborhood (standard deviation); controls the spread of the update influence. Defaults to 1.0.
            learning_rate (float, optional): Strength of the weight updates. Defaults to 0.5.
            neighborhood_order (int, optional): Number of neighbor rings to consider for distance calculations. Defaults to 1.
            topology (str, optional): Grid configuration, "rectangular" or "hexagonal". Defaults to "rectangular".
            lr_decay_function (str, optional): Schedule that decreases the learning rate each epoch. Defaults to "asymptotic_decay".
            sigma_decay_function (str, optional): Schedule that decreases sigma each epoch. Defaults to "asymptotic_decay".
            neighborhood_function (str, optional): Kernel used to weight neighbor updates. Defaults to "gaussian".
            distance_function (str, optional): Distance between grid weights and input data. Defaults to "euclidean".
            initialization_mode (str, optional): Method to initialize SOM weights. Defaults to "random".
            device (str, optional): Allocate tensors on CPU or GPU. Defaults to "cuda" if available, else "cpu".
            random_seed (int, optional): Ensure reproducibility. Defaults to 42.

        Raises:
            ValueError: Ensure valid topology
        """
        super().__init__()  # Py3 zero-argument form of super(SOM, self).__init__()

        # Validate parameters
        if sigma > torch.sqrt(torch.tensor(float(x * x + y * y))):
            warnings.warn(
                "Warning: sigma might be too high for the dimension of the map."
            )
        if topology not in ["hexagonal", "rectangular"]:
            raise ValueError("Only hexagonal and rectangular topologies are supported")

        # Input parameters
        self.x = x
        self.y = y
        self.num_features = num_features
        self.sigma = sigma
        self.learning_rate = learning_rate
        self.epochs = epochs
        self.batch_size = batch_size
        self.device = device
        self.topology = topology
        self.random_seed = random_seed
        self.neighborhood_order = neighborhood_order
        self.distance_fn_name = distance_function
        self.initialization_mode = initialization_mode
        self.distance_fn = DISTANCE_FUNCTIONS[distance_function]
        self.lr_decay_fn = DECAY_FUNCTIONS[lr_decay_function]
        self.sigma_decay_fn = DECAY_FUNCTIONS[sigma_decay_function]

        # Set up x and y mesh grids, adjusted for the chosen topology
        x_meshgrid, y_meshgrid = create_mesh_grid(x, y, device)
        self.xx, self.yy = adjust_meshgrid_topology(x_meshgrid, y_meshgrid, topology)

        # Neighborhood function: closure over the selected kernel and the grids
        self.neighborhood_fn = lambda win_neuron, sigma: NEIGHBORHOOD_FUNCTIONS[
            neighborhood_function
        ](self.xx, self.yy, win_neuron, sigma)

        # Ensure reproducibility
        torch.manual_seed(random_seed)

        # Initialize & normalize weights
        # NOTE(review): `2 * randn - 1` samples N(-1, 4); the `2*u - 1` pattern is
        # usually paired with torch.rand for a uniform [-1, 1] draw — confirm intent.
        weights = 2 * torch.randn(x, y, num_features, device=device) - 1
        normalized_weights = weights / torch.norm(weights, dim=-1, keepdim=True)
        self.weights = nn.Parameter(normalized_weights, requires_grad=False)

    def _update_weights(
        self,
        data: torch.Tensor,
        bmus: Union[Tuple[int, int], torch.Tensor],
        learning_rate: float,
        sigma: float,
    ) -> None:
        """Update weights using the neighborhood function. Handles both single samples and batches.

        Args:
            data (torch.Tensor): Input tensor of shape [num_features] or [batch_size, num_features]
            bmus (Union[Tuple[int, int], torch.Tensor]): BMU coordinates as tuple (single) or tensor (batch)
            learning_rate (float): Current learning rate
            sigma (float): Current sigma value
        """
        # Single sample
        if isinstance(bmus, tuple):
            # Neighborhood contribution of the BMU, reshaped for broadcasting
            neighborhood = self.neighborhood_fn(bmus, sigma)
            neighborhood = neighborhood.view(self.x, self.y, 1)
            # Calculate and apply the update for the single sample
            update = learning_rate * neighborhood * (data - self.weights)
            self.weights.data += update
        # Batch samples
        else:
            batch_size = data.shape[0]
            # One neighborhood map per BMU: [batch_size, row_neurons, col_neurons]
            # NOTE(review): this per-sample loop could be vectorized if neighborhood_fn
            # accepted batched coordinates — confirm before changing.
            neighborhoods = torch.stack(
                [
                    self.neighborhood_fn((row.item(), col.item()), sigma)
                    for row, col in bmus
                ]
            )
            # Reshape for broadcasting against the weight grid
            neighborhoods = neighborhoods.view(batch_size, self.x, self.y, 1)
            data_expanded = data.view(batch_size, 1, 1, self.num_features)
            # Calculate the updates for all samples
            updates = learning_rate * neighborhoods * (data_expanded - self.weights)
            # Average updates across batch and apply to weights
            self.weights.data += updates.mean(dim=0)

    def _calculate_distances_to_neurons(
        self,
        data: torch.Tensor,
    ) -> torch.Tensor:
        """Calculate distances between input data and all neurons' weights.

        Handles both single samples and batches.

        Args:
            data (torch.Tensor): Input tensor of shape [num_features] or [batch_size, num_features]

        Returns:
            torch.Tensor: Distances of shape [row_neurons, col_neurons] (single) or
                [batch_size, row_neurons, col_neurons] (batch)
        """
        # Ensure device and batch compatibility
        data = data.to(self.device)
        if data.dim() == 1:
            data = data.unsqueeze(0)
        data_batch_size = data.shape[0]

        # Reshape data and weights for broadcasting in the distance function
        data_expanded = data.view(
            data_batch_size, 1, 1, self.num_features
        )  # [batch_size, 1, 1, num_features]
        weights_expanded = self.weights.unsqueeze(
            0
        )  # [1, row_neurons, col_neurons, num_features]

        # Compute distances for the whole batch [batch_size, row_neurons, col_neurons]
        distances = self.distance_fn(data_expanded, weights_expanded)

        # Single sample case - remove batch dimension
        if data_batch_size == 1:
            distances = distances.squeeze(0)
        return distances
[docs] def identify_bmus( self, data: torch.Tensor, ) -> torch.Tensor: """Find BMUs for input data. Handles both single samples and batches. It requires a data on the GPU if available for calculations with SOM's weights on GPU's too. Args: data (torch.Tensor): Input tensor of shape [num_features] or [batch_size, num_features] Returns: torch.Tensor: For single sample: Tensor of shape [2] with [row, col]. For batch: Tensor of shape [batch_size, 2] with [row, col] pairs """ distances = self._calculate_distances_to_neurons(data) # Unique sample [row_neurons, col_neurons] if distances.dim() == 2: index = torch.argmin( distances.view(-1) ) # From 2D tensor [m,n] to 1D tensor [m*n] then retrieve the index of the bmu with the smallest distance row, col = torch.unravel_index( index, (self.x, self.y), ) # Convert the index to 2D coordinates coords = torch.stack([row, col], dim=0).to(data.device) return coords # Batch samples [batch_size, row_neurons, col_neurons] else: indices = torch.argmin( distances.view(distances.shape[0], -1), dim=1 ) # From 3D tensor [batch_size, m, n] to 2D tensor [batch_size, m*n] then retrieve the index of the bmu with the smallest distance for all samples return torch.stack( [torch.div(indices, self.y, rounding_mode="floor"), indices % self.y], dim=1, )
[docs] def quantization_error( self, data: torch.Tensor, ) -> float: """Calculate quantization error. Args: data (torch.Tensor): input data tensor [batch_size, num_features] or [num_features] Returns: float: Average quantization error value """ # Ensure device and batch compatibility data = data.to(self.device) if data.dim() == 1: data = data.unsqueeze(0) # Use the utility function for calculation return calculate_quantization_error(data, self.weights, self.distance_fn)
[docs] def topographic_error( self, data: torch.Tensor, ) -> float: """Calculate topographic error with batch support Args: data (torch.Tensor): input data tensor [batch_size, num_features] or [num_features] Returns: float: Topographic error ratio """ # Ensure device and batch compatibility data = data.to(self.device) if data.dim() == 1: data = data.unsqueeze(0) return calculate_topographic_error( data, self.weights, self.distance_fn, self.topology )
[docs] def initialize_weights( self, data: torch.Tensor, mode: str = None, ) -> None: """Data should be normalized before initialization. Initialize weights using 1. Random samples from input data. 2. PCA components to make the training process converge faster. Args: data (torch.Tensor): input data tensor [batch_size, num_features] mode (str, optional): selection of the method to init the weights. Defaults to None. Raises: ValueError: Ensure neurons' weights and input data have the same number of features RuntimeError: If random initialization takes too long ValueError: Requires at least 2 features for PCA ValueError: Requires more than one sample to perform PCA ValueError: Ensure an appropriate method for initialization """ data = data.to(self.device) if data.shape[1] != self.num_features: raise ValueError( f"Input data dimension ({data.shape[1]}) and weights dimension ({self.num_features}) don't match" ) if mode is None: mode = self.initialization_mode # Use utility function for initialization new_weights = initialize_weights( self.weights.data, data, mode, self.topology, self.device ) self.weights.data = new_weights
[docs] def fit( self, data: torch.Tensor, ) -> Tuple[List[float], List[float]]: """Train the SOM using batches and track errors. Args: data (torch.Tensor): input data tensor [batch_size, num_features] Returns: Tuple[List[float], List[float]]: Quantization and topographic errors [epoch] """ # data = data.to(self.device) dataset = TensorDataset(data) dataloader = DataLoader( dataset, batch_size=self.batch_size, shuffle=True, pin_memory=False ) q_errors = [] t_errors = [] for epoch in tqdm( range(self.epochs), desc="Training SOM", unit="epoch", disable=False, ): # Update learning parameters through decay function (schedulers) lr = self.lr_decay_fn(self.learning_rate, t=epoch, max_iter=self.epochs) sigma = self.sigma_decay_fn(self.sigma, t=epoch, max_iter=self.epochs) epoch_q_errors = [] epoch_t_errors = [] for batch in dataloader: batch_data = batch[0].to(self.device) # Get BMUs for all data points at once [batch_size, 2] with torch.no_grad(): bmus = self.identify_bmus(batch_data) # Update the weights of each neuron self._update_weights(batch_data, bmus, lr, sigma) # Calculate both errors at each batch and store them with torch.no_grad(): epoch_q_errors.append(self.quantization_error(batch_data)) epoch_t_errors.append(self.topographic_error(batch_data)) # Clean GPU memory torch.cuda.empty_cache() # Compute both average errors at each epoch and store them q_errors.append(torch.tensor(epoch_q_errors).mean().item()) t_errors.append(100 * torch.tensor(epoch_t_errors).mean().item()) return q_errors, t_errors
[docs] def collect_samples( self, query_sample: torch.Tensor, historical_samples: torch.Tensor, historical_outputs: torch.Tensor, min_buffer_threshold: int = 50, bmus_idx_map: Dict[Tuple[int, int], List[int]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """Collect historical samples similar to the query sample using SOM projection. Args: query_sample (torch.Tensor): The query data point [num_features] historical_samples (torch.Tensor): Historical input data [num_samples, num_features] historical_outputs (torch.Tensor): Historical output values [num_samples] min_buffer_threshold (int, optional): Minimum number of samples to collect. Defaults to 50. Returns: Tuple[torch.Tensor, torch.Tensor]: (historical_data_buffer, historical_output_buffer) """ # Ensure device compatibility query_sample = query_sample.to(self.device) # Find BMU for the query sample with torch.no_grad(): bmu_pos = self.identify_bmus(query_sample) bmu_tuple = (int(bmu_pos[0].item()), int(bmu_pos[1].item())) # Collect samples indices from the query's BMU if any exist # ! 
DUE TO CHANGES IN TORCHSOM, bmus_idx_map is on cpu now even with gpus sample_indices = [] if bmu_tuple in bmus_idx_map and len(bmus_idx_map[bmu_tuple]) > 0: sample_indices.extend(bmus_idx_map[bmu_tuple]) # Keep track of the neurons used to build the historical buffers visited_neurons = {bmu_tuple} # Get all neighbor offsets based on topology all_offsets = get_all_neighbors_up_to_order( topology=self.topology, max_order=self.neighborhood_order, ) # Handle topology-specific offset processing if self.topology == "rectangular": for dx, dy in all_offsets: neighbor_pos = ( int(bmu_pos[0].item() + dx), int(bmu_pos[1].item() + dy), ) if neighbor_pos in visited_neurons: continue visited_neurons.add(neighbor_pos) # Check if the neighbor is 1) within SOM bounds, and 2) activated if ( 0 <= neighbor_pos[0] < self.x and 0 <= neighbor_pos[1] < self.y and neighbor_pos in bmus_idx_map ): sample_indices.extend(bmus_idx_map[neighbor_pos]) elif self.topology == "hexagonal": bmu_row = int(bmu_pos[0].item()) row_type = "even" if bmu_row % 2 == 0 else "odd" for dx, dy in all_offsets[row_type]: neighbor_pos = ( int(bmu_pos[0].item() + dx), int(bmu_pos[1].item() + dy), ) if neighbor_pos in visited_neurons: continue visited_neurons.add(neighbor_pos) # Check if the neighbor is 1) within SOM bounds, and 2) activated if ( 0 <= neighbor_pos[0] < self.x and 0 <= neighbor_pos[1] < self.y and neighbor_pos in bmus_idx_map ): sample_indices.extend(bmus_idx_map[neighbor_pos]) """ Secondly, ensure we have enough training samples. This time, explore neighbors that are close in terms of distance in the weights space. 
""" if len(sample_indices) <= min_buffer_threshold: # Calculate distances from BMU weights to all neurons with torch.no_grad(): neurons_distance_map = self._calculate_distances_to_neurons( data=self.weights.data[bmu_pos[0], bmu_pos[1]] ) # Build min heap of (distance, position) for unvisited neurons with samples distance_min_heap = [] for row in range(self.x): for col in range(self.y): neuron_pos = (row, col) if neuron_pos in visited_neurons: continue if neuron_pos in bmus_idx_map and len(bmus_idx_map[neuron_pos]) > 0: distance = neurons_distance_map[row, col].item() heapq.heappush(distance_min_heap, (distance, neuron_pos)) # Add samples until threshold is reached while distance_min_heap and len(sample_indices) <= min_buffer_threshold: _, closest_neuron = heapq.heappop(distance_min_heap) visited_neurons.add(closest_neuron) if closest_neuron in bmus_idx_map: sample_indices.extend(bmus_idx_map[closest_neuron]) historical_data_buffer = historical_samples[sample_indices] historical_output_buffer = historical_outputs[sample_indices].view(-1, 1) return historical_data_buffer, historical_output_buffer
[docs] def build_hit_map( self, data: torch.Tensor, batch_size: int = 1024, ) -> torch.Tensor: """Returns a matrix where element i,j is the number of times that neuron i,j has been the winner. It processes the data in batches to save memory. The hit map is built on CPU, but the calculations are done on GPU if available. Args: data (torch.Tensor): input data tensor [batch_size, num_features] batch_size (int, optional): Size of batches to process. Defaults to 1024. Returns: torch.Tensor: Matrix indicating the number of times each neuron has been identified as bmu. """ # Ensure batch compatibility if data.dim() == 1: data = data.unsqueeze(0) # Initialize hit map on CPU hit_map = torch.zeros((self.x, self.y)) # Process data in batches to save GPU memory num_samples = data.shape[0] num_batches = (num_samples + batch_size - 1) // batch_size for batch_idx in range(num_batches): # Retrieve corresponding batches and move them to device start_idx = batch_idx * batch_size end_idx = min((batch_idx + 1) * batch_size, num_samples) current_batch_size = end_idx - start_idx batch_data = data[start_idx:end_idx].to(self.device) # Get BMUs for this batch batch_bmus = self.identify_bmus(batch_data) # Handle special case when batch has only one sample if current_batch_size == 1: # If only one sample, ensure batch_bmus is properly shaped if batch_bmus.dim() == 1: batch_bmus = batch_bmus.unsqueeze(0) row, col = batch_bmus[0] hit_map[row.item(), col.item()] += 1 # Otherwise, process multiple samples normally else: # Update and store hit map on CPU for row, col in batch_bmus: hit_map[row.item(), col.item()] += 1 # Clean up GPU memory del batch_data, batch_bmus if torch.cuda.is_available(): torch.cuda.empty_cache() return hit_map
[docs] def build_distance_map( self, scaling: str = "sum", distance_metric: str = None, neighborhood_order: int = None, ) -> torch.Tensor: """Computes the distance map of each neuron with its neighbors. The distance map represents the normalized sum or mean of distances between a neuron's weight vector and its neighboring neurons. Args: scaling (str, optional): Defaults to "sum". If 'mean', each cell is normalized by the average neighbor distance. If 'sum', normalization is done by the sum of distances. distance_metric (str, optional): Name of the method to calculate the distance. Defaults to None. neighborhood_order (int, optional): Indicate the neighbors to consider for the distance calculation. Defaults to None. Raises: ValueError: If an invalid scaling option is provided. ValueError: If an invalid distance metric is provided. Returns: torch.Tensor: Normalized distance map [row_neurons, col_neurons] """ if scaling not in ["sum", "mean"]: raise ValueError( f'scaling should be either "sum" or "mean" ({scaling} is not valid)' ) # Use instance neighborhood_order if not specified if neighborhood_order is None: neighborhood_order = self.neighborhood_order # Indicate the distance function to use if distance_metric is None: distance_fn = self.distance_fn else: if distance_metric not in DISTANCE_FUNCTIONS: raise ValueError(f"Unsupported distance metric: {distance_metric}") distance_fn = DISTANCE_FUNCTIONS[distance_metric] # Get all neighbor offsets based on topology all_offsets = get_all_neighbors_up_to_order( topology=self.topology, max_order=neighborhood_order, ) # Calculate maximum possible neighbors for tensor initialization if self.topology == "hexagonal": # For hexagonal, we need to handle even/odd rows separately max_neighbors = max(len(all_offsets["even"]), len(all_offsets["odd"])) else: # For rectangular topology max_neighbors = len(all_offsets) # Initialize distance map distance_matrix = torch.full( (self.weights.shape[0], self.weights.shape[1], max_neighbors), 
float("nan"), device=self.device, ) # Compute distances for each neuron for row in range(self.weights.shape[0]): for col in range(self.weights.shape[1]): current_neuron = self.weights[row, col] neighbor_idx = 0 # Handle topology-specific neighbor processing if self.topology == "hexagonal": # Use appropriate offsets based on row parity (even/odd) row_offsets = ( all_offsets["even"] if row % 2 == 0 else all_offsets["odd"] ) for row_offset, col_offset in row_offsets: neighbor_row = row + row_offset neighbor_col = col + col_offset # Ensure neighbor is within bounds to compute the distance if ( 0 <= neighbor_row < self.weights.shape[0] and 0 <= neighbor_col < self.weights.shape[1] ): neighbor_neuron = self.weights[neighbor_row, neighbor_col] """ Reshape weights to ensure batch compatibility with distance function => shape [a,b] becomes [1,a,b] after unsqueeze(0) Each neuron has a shape of [num_features] so they become [1,num_features] and then [1,1,num_features] Finally, distance function need to be squeezed because it returns [batch_size, 1] but there is only one sample, so let's just retrieve the scalar """ solo_batch_current_neuron = current_neuron.unsqueeze( 0 ).unsqueeze(0) solo_batch_neighbor_neuron = neighbor_neuron.unsqueeze( 0 ).unsqueeze(0) # Calculate and store the distance distance_matrix[row, col, neighbor_idx] = distance_fn( solo_batch_current_neuron, solo_batch_neighbor_neuron, ).squeeze() neighbor_idx += 1 else: # Rectangular topology - process all offsets directly for row_offset, col_offset in all_offsets: neighbor_row = row + row_offset neighbor_col = col + col_offset # Ensure neighbor is within bounds to compute the distance if ( 0 <= neighbor_row < self.weights.shape[0] and 0 <= neighbor_col < self.weights.shape[1] ): neighbor_neuron = self.weights[neighbor_row, neighbor_col] """ Reshape weights to ensure batch compatibility with distance function => shape [a,b] becomes [1,a,b] after unsqueeze(0) Each neuron has a shape of [num_features] so they 
become [1,num_features] and then [1,1,num_features] Finally, distance function need to be squeezed because it returns [batch_size, 1] but there is only one sample, so let's just retrieve the scalar """ solo_batch_current_neuron = current_neuron.unsqueeze( 0 ).unsqueeze(0) solo_batch_neighbor_neuron = neighbor_neuron.unsqueeze( 0 ).unsqueeze(0) # Calculate and store the distance distance_matrix[row, col, neighbor_idx] = distance_fn( solo_batch_current_neuron, solo_batch_neighbor_neuron, ).squeeze() neighbor_idx += 1 """ Aggregate distances (either sum or mean). Each neuron has approximately k distances based on the topology (and bounds). Compute the aggregation on the last dimension where all the ,neighbor distances are computed. Both torch methods ignore NaNs. """ if scaling == "mean": distance_matrix = torch.nanmean(distance_matrix, dim=2) else: distance_matrix = torch.nansum(distance_matrix, dim=2) # Normalize the distance map max_distance = torch.max( distance_matrix.masked_fill(torch.isnan(distance_matrix), float("-inf")) ) # Replace NaNs with -inf to be ignored by max() return distance_matrix / max_distance if max_distance > 0 else distance_matrix
[docs] def build_bmus_data_map( self, data: torch.Tensor, return_indices: bool = False, batch_size: int = 1024, ) -> Dict[Tuple[int, int], Any]: """Create a mapping of winning neurons to their corresponding data points. It processes the data in batches to save memory. The hit map is built on CPU, but the calculations are done on GPU if available. Args: data (torch.Tensor): input data tensor [num_samples, num_features] or [num_features] return_indices (bool, optional): If True, return indices instead of data points. Defaults to False. batch_size (int, optional): Size of batches to process. Defaults to 1024. Returns: Dict[Tuple[int, int], Any]: Dictionary mapping bmus to data samples or indices """ # Ensure batch compatibility if data.dim() == 1: data = data.unsqueeze(0) # Initialize the map on CPU bmus_data_map = defaultdict(list) # Process data in batches to save GPU memory num_samples = data.shape[0] num_batches = (num_samples + batch_size - 1) // batch_size for batch_idx in range(num_batches): # Retrieve corresponding batches and move them to device start_idx = batch_idx * batch_size end_idx = min((batch_idx + 1) * batch_size, num_samples) current_batch_size = end_idx - start_idx batch_data = data[start_idx:end_idx].to(self.device) # Get BMUs for this batch batch_bmus = self.identify_bmus(batch_data) # Handle special case when batch has only one sample if current_batch_size == 1: # If only one sample, ensure batch_bmus is properly shaped if batch_bmus.dim() == 1: batch_bmus = batch_bmus.unsqueeze(0) row, col = batch_bmus[0] bmu_pos = (int(row.item()), int(col.item())) if return_indices: bmus_data_map[bmu_pos].append(start_idx) else: bmus_data_map[bmu_pos].append(batch_data[0].cpu()) # Otherwise, process multiple samples normally else: # Add the BMUs to the map for i, (row, col) in enumerate(batch_bmus): # Convert BMU coordinates to integer tuple for dictionary key bmu_pos = (int(row.item()), int(col.item())) # Global index for this data point global_idx = 
start_idx + i # Add to map based on return_indices flag if return_indices: bmus_data_map[bmu_pos].append(global_idx) else: # Store the data on CPU to save GPU memory bmus_data_map[bmu_pos].append(batch_data[i].cpu()) # Clean up GPU memory del batch_data, batch_bmus if torch.cuda.is_available(): torch.cuda.empty_cache() # Convert lists to tensors if returning data points if not return_indices: for bmu in bmus_data_map: bmus_data_map[bmu] = torch.stack(bmus_data_map[bmu]) return bmus_data_map
[docs] def build_metric_map( self, data: torch.Tensor, target: torch.Tensor, reduction_parameter: str, ) -> torch.Tensor: """Calculate neurons' metrics based on target values. Args: data (torch.Tensor): Input data tensor [batch_size, num_features] target (torch.Tensor): Labels tensor for data points [batch_size] reduction_parameter (str): Decide the calculation to apply to each neuron, 'mean' or 'std'. Returns: torch.Tensor: Metric map based on the reduction parameter. """ epsilon = 1e-8 bmus_map = self.build_bmus_data_map(data, return_indices=True) metric_map = torch.full((self.x, self.y), float("nan")) # For each activated neuron, calculate the corresponding target metric for bmu_pos, samples_indices in bmus_map.items(): if len(samples_indices) > 0: if reduction_parameter == "mean": metric_map[bmu_pos] = torch.mean(target[samples_indices]) elif reduction_parameter == "std": if len(samples_indices) > 1: metric_map[bmu_pos] = torch.std( target[samples_indices], unbiased=True ) else: metric_map[bmu_pos] = ( epsilon # Ensure visualization with a non-zero value ) return metric_map
[docs] def build_score_map( self, data: torch.Tensor, target: torch.Tensor, ) -> torch.Tensor: """Calculate neurons' score based on target values. Args: data (torch.Tensor): Input data tensor [batch_size, num_features] target (torch.Tensor): Labels tensor for data points [batch_size] Returns: torch.Tensor: Score map based on a chosen score function: std_neuron / sqrt(n_neuron) * log(N_data/n_neuron). The score combines the standard error with a term penalizing uneven sample distribution across neurons. Lower scores indicate better neuron representativeness. """ epsilon = 1e-8 bmus_map = self.build_bmus_data_map(data, return_indices=True) score_map = torch.full((self.x, self.y), float("nan")) # For each activated neurons, calculate the corresponding target metric for bmu_pos, samples_indices in bmus_map.items(): if len(samples_indices) > 0: # Consider neuron with multiple elements if len(samples_indices) > 1: std = torch.std(target[samples_indices], unbiased=True) n_samples = torch.tensor(len(samples_indices), dtype=torch.float32) total_samples = torch.tensor(len(data), dtype=torch.float32) neuron_score = (std / torch.sqrt(n_samples)) * torch.log( total_samples / n_samples ) # Consider neuron with a unique element else: # Tensor to initialize tensor from scalars and ensure visualization with a non-zero value neuron_score = torch.tensor(epsilon, dtype=torch.float32) score_map[bmu_pos] = ( round(neuron_score.item(), 2) if neuron_score > epsilon else epsilon ) return score_map
[docs] def build_rank_map( self, data: torch.Tensor, target: torch.Tensor, ) -> torch.Tensor: """Build a map of neuron ranks based on their target value standard deviations. Args: data (torch.Tensor): Input data tensor [batch_size, num_features] target (torch.Tensor): Labels tensor for data points [batch_size] Returns: torch.Tensor: Rank map where each neuron's value is its rank (1 = lowest std = best) """ bmus_map = self.build_bmus_data_map(data, return_indices=True) neuron_stds = torch.full((self.x, self.y), float("nan")) # Calculate standard deviation for each neuron active_neurons = 0 for bmu_pos, sample_indices in bmus_map.items(): if len(sample_indices) > 0: active_neurons += 1 # Consider neuron with multiple elements if len(sample_indices) > 1: neuron_stds[bmu_pos] = torch.std( target[sample_indices], unbiased=True ).item() # Use unbiased estimator for better small sample handling # Consider neuron with a unique element else: neuron_stds[bmu_pos] = 0.0 # rank_map = torch.full((self.x, self.y), float("nan"), device=self.device) rank_map = torch.full((self.x, self.y), float("nan")) # Get mask to retrieve indices of non-NaN values valid_mask = ~torch.isnan(neuron_stds) valid_stds = neuron_stds[valid_mask] if len(valid_stds) > 0: # Sort stds in descending order and get ranks (+ 1 to make ranks 1-based) ranks = torch.argsort(valid_stds, descending=True).argsort() + 1 # Ensure there are as many ranks as activated neurons assert ( len(ranks) == active_neurons ), f"Rank count ({len(ranks)}) doesn't match active neurons ({active_neurons})" # Place ranks back in the map rank_map[valid_mask] = ranks.float() return rank_map
[docs] def build_classification_map( self, data: torch.Tensor, target: torch.Tensor, neighborhood_order: int = 1, ) -> torch.Tensor: """ Build a classification map where each neuron is assigned the most frequent label. In case of a tie, consider labels from neighboring neurons. If there are no neighboring neurons or a second tie, then randomly select one of the top label. Args: data (torch.Tensor): Input data tensor [batch_size, num_features] target (torch.Tensor): Labels tensor for data points [batch_size]. They are assumed to be encoded with value > 1 for decent visualization. neighborhood_order (int, optional): Neighborhood order to consider for tie-breaking. Defaults to 1. Returns: torch.Tensor: Classification map with the most frequent label for each neuron """ bmus_map = self.build_bmus_data_map(data, return_indices=True) classification_map = torch.full((self.x, self.y), float("nan")) neighborhood_offsets = get_all_neighbors_up_to_order( topology=self.topology, max_order=neighborhood_order, ) # Iterate through each activated neuron for bmu_pos, sample_indices in bmus_map.items(): if len(sample_indices) > 0: """ Retrieve the labels of all samples attached to current neuron Find the most common one Check if there is a tie with another label """ neuron_labels = target[sample_indices].cpu().numpy() label_counts = Counter(neuron_labels) max_count = max(label_counts.values()) top_labels = [ label for label, count in label_counts.items() if count == max_count ] """ If there is not tie, assign the most common label to the neuron. In case of a tie, consider labels from neighboring neurons to break it. 
""" if len(top_labels) == 1: classification_map[bmu_pos] = torch.tensor( top_labels[0], dtype=classification_map.dtype ) # Convert NumPy value to tensor scalar else: neighbor_labels = [] row, col = bmu_pos for dx, dy in neighborhood_offsets: neighbor_row = row + dx neighbor_col = col + dy if ( 0 <= neighbor_row < self.x and 0 <= neighbor_col < self.y and (neighbor_row, neighbor_col) in bmus_map ): neighbor_samples_indices = bmus_map[ (neighbor_row, neighbor_col) ] neighbor_labels.extend( target[neighbor_samples_indices].cpu().numpy() ) # After collecting all neighbor labels, recompute label counts with neighborhood labels. if neighbor_labels: expanded_label_counts = Counter(neighbor_labels) max_neighbor_count = max(expanded_label_counts.values()) top_neighbor_labels = [ label for label, count in expanded_label_counts.items() if count == max_neighbor_count ] # If there is a tie with neighbor labels, choose randomly between top labels (including neighbors). if len(top_neighbor_labels) == 1: classification_map[bmu_pos] = torch.tensor( top_neighbor_labels[0], dtype=classification_map.dtype ) else: # Choose randomly and convert to tensor chosen_label = random.choice(top_neighbor_labels) classification_map[bmu_pos] = torch.tensor( chosen_label, dtype=classification_map.dtype ) # If there are no neighbor labels, choose randomly between previous top labels. else: # Choose randomly and convert to tensor chosen_label = random.choice(top_labels) classification_map[bmu_pos] = torch.tensor( chosen_label, dtype=classification_map.dtype ) return classification_map