Source code for torchsom.utils.initialization

import warnings
from typing import Optional, Union

import torch

from ..utils.grid import adjust_meshgrid_topology


[docs] def random_init( weights: torch.Tensor, data: torch.Tensor, device: str = "cuda" if torch.cuda.is_available() else "cpu", ) -> torch.Tensor: """Initialize SOM weights by sampling random data points. Args: weights (torch.Tensor): Weight tensor to initialize [row_neurons, col_neurons, num_features] data (torch.Tensor): Input data tensor to sample from [batch_size, num_features] device (str, optional): Device for tensor computations. Defaults to "cuda" if available, else "cpu". Returns: torch.Tensor: Initialized weights """ # Ensure data is on the correct device data = data.to(device) try: # Generate random indices for sampling indices = torch.randint( 0, len(data), (weights.shape[0], weights.shape[1]), device=device ) # Sample data points and assign to weights sampled_weights = data[indices] return sampled_weights # ! Modification to test # # Return value ignores the original weights tensor # weights.copy_(data[indices]) # return weights except RuntimeError as e: raise RuntimeError(f"Random initialization failed: {str(e)}")
[docs] def pca_init( weights: torch.Tensor, data: torch.Tensor, topology: str, device: str = "cuda" if torch.cuda.is_available() else "cpu", ) -> torch.Tensor: """Initialize SOM weights using PCA for faster convergence. Args: weights (torch.Tensor): Weight tensor to initialize [row_neurons, col_neurons, num_features] data (torch.Tensor): Input data tensor [batch_size, num_features] topology (str): Grid configuration, "rectangular" or "hexagonal" device (str, optional): Device for tensor computations. Defaults to "cuda" if available, else "cpu". Returns: torch.Tensor: Initialized weights """ # Ensure data is on the correct device data = data.to(device) if weights.shape[2] == 1: raise ValueError("Data needs at least 2 features for PCA initialization") if weights.shape[0] == 1 or weights.shape[1] == 1: warnings.warn("PCA initialization may be inappropriate for 1D map") try: # Center the data efficiently using running mean data_mean = data.mean(dim=0, keepdim=True) data_centered = data - data_mean # Compute covariance matrix with improved numerical stability n_samples = len(data) if n_samples == 1: raise ValueError("Cannot perform PCA on a single sample") cov = torch.mm(data_centered.T, data_centered) / (n_samples - 1) # Try SVD first (more stable than eigendecomposition) try: U, S, V = torch.linalg.svd( cov, driver=None, # Default is None, but also: "gesvd" (small), "gesvdj" (medium), and "gesvda" (large) full_matrices=True, # Default is True ) pc = V[:2] # Take first two principal components except RuntimeError: warnings.warn("SVD failed, falling back to eigendecomposition") eigenvalues, eigenvectors = torch.linalg.eigh(cov) idx = torch.argsort( eigenvalues, descending=True ) # Sort eigenvectors by eigenvalues in descending order pc = eigenvectors[ :, idx[:2] ].T # Works properly ! Results seems identical to driver=None # Create coordinate grid for initialization x_coords = torch.linspace(-1, 1, weights.shape[0], device=device) y_coords = torch.linspace(-1, 1, weights.shape[1], device=device) grid_x, grid_y = torch.meshgrid(x_coords, y_coords, indexing="ij") adj_grid_x, adj_grid_y = adjust_meshgrid_topology( xx=grid_x, yy=grid_y, topology=topology ) # Initialize weights using broadcasting pca_weights = adj_grid_x.unsqueeze(-1) * pc[0].unsqueeze(0).unsqueeze( 0 ) + adj_grid_y.unsqueeze(-1) * pc[1].unsqueeze(0).unsqueeze(0) # Scale weights to match data distribution weights_std = pca_weights.std() if weights_std > 0: pca_weights = pca_weights * (data.std() / weights_std) # Add back the mean return pca_weights + data_mean except Exception as e: warnings.warn( f"PCA initialization failed: {str(e)}. Falling back to random initialization" ) return random_init(weights, data, device)
[docs] def initialize_weights( weights: torch.Tensor, data: torch.Tensor, mode: str = "random", topology: str = "rectangular", device: str = "cuda" if torch.cuda.is_available() else "cpu", ) -> torch.Tensor: """Main function to initialize weights based on specified method. Args: weights (torch.Tensor): Weight tensor to initialize [row_neurons, col_neurons, num_features] data (torch.Tensor): Input data tensor [batch_size, num_features] mode (str, optional): Initialization method, "random" or "pca". Defaults to "random". topology (str, optional): Grid configuration, "rectangular" or "hexagonal". Defaults to "rectangular". device (str, optional): Device for tensor computations. Defaults to "cuda" if available, else "cpu". Returns: torch.Tensor: Initialized weights Raises: ValueError: If an invalid initialization mode is provided """ if data.shape[1] != weights.shape[2]: raise ValueError( f"Input data dimension ({data.shape[1]}) and weights dimension ({weights.shape[2]}) don't match" ) if mode == "random": return random_init(weights, data, device) elif mode == "pca": return pca_init(weights, data, topology, device) else: raise ValueError( "The only method to initialize the weights are 'random' or 'pca'." )