Skip to content

API Reference

Main Interface

ncut_pytorch.ncut.Ncut

Class interface for Normalized Cut. Saves the state of the Nyström approximation so that it can be used to transform new data.

Source code in ncut_pytorch/ncut.py
class Ncut:
    """
    Class interface for Normalized Cut. Saves the state of the Nystrom approximation
    so that it can be used to transform new data.
    """

    def __init__(
            self,
            n_eig: int = 100,
            quantile_sigma: float = 0.25,
            quantile_sigma_repulsion: float = 0.20,
            sigma: float | None = None,
            repulsion_sigma: float | None = None,
            repulsion_weight: float | None = None,
            affinity_fn: Union["rbf_affinity", "cosine_affinity"] = rbf_affinity,
            extrapolation_factor: float = 1.0,
            exact_gradient: bool = False,
            device: str | None = None,
            **kwargs,
    ):
        """
        Store the configuration; no computation happens until fit().

        Args:
            n_eig (int): number of eigenvectors
            quantile_sigma (float): quantile of affinity sigma parameter, lower quantile_sigma results in sharper eigenvectors
            quantile_sigma_repulsion (float): quantile of repulsion sigma parameter, lower quantile_sigma_repulsion results in sharper eigenvectors
            sigma (float | None): affinity parameter; if provided, overrides the sigma estimated from quantile_sigma
            repulsion_sigma (float | None): (if use repulsion) repulsion sigma parameter; if None it is presumably estimated from quantile_sigma_repulsion, default None
            repulsion_weight (float | None): (if use repulsion) repulsion weight; repulsion is only applied when this is set (non-zero), default None (no repulsion)
            affinity_fn (callable): affinity function, default rbf_affinity. Should accept (X1, X2=None, sigma=float) and return affinity matrix
            extrapolation_factor (float): control how far transform() can extrapolate, larger extrapolation_factor means we can extrapolate further, default 1.0
            exact_gradient (bool): use full spectrum and exact gradient, can be slower and unstable, default False
            device (str | None): device for computation, default None (auto-detected, e.g. GPU if available)
            **kwargs: extra options (e.g. Nystrom config overrides) forwarded to ncut_fn in fit() and to nystrom_propagate in transform()

        Examples:
            >>> from ncut_pytorch import Ncut
            >>> import torch
            >>> X = torch.rand(10000, 100)
            >>> ncut = Ncut(n_eig=20)
            >>> eigvec = ncut.fit_transform(X)
            >>> eigval = ncut.eigval
            >>> print(eigvec.shape, eigval.shape)  # (10000, 20) (20,)
            >>> 
            >>> # transform new data
            >>> new_X = torch.rand(500, 100)
            >>> new_eigvec = ncut.transform(new_X)
            >>> print(new_eigvec.shape)  # (500, 20)
        """
        self.n_eig = n_eig
        self.quantile_sigma = quantile_sigma
        self.quantile_sigma_repulsion = quantile_sigma_repulsion
        self.sigma = sigma
        self.repulsion_sigma = repulsion_sigma
        self.repulsion_weight = repulsion_weight
        self.extrapolation_factor = extrapolation_factor
        self.exact_gradient = exact_gradient
        self.device = device
        self.affinity_fn = affinity_fn
        self.kwargs = kwargs

        # fitted state, populated by fit()
        self._nystrom_x = None
        self._nystrom_eigvec = None
        self._eigval = None

        # k-way rotation matrices (stored on CPU), keyed by (n_clusters, n_eig); filled by kway_fit()
        self._kway_R: dict[tuple[int, int], torch.Tensor] = {}

    @property
    def eigval(self) -> torch.Tensor | None:
        """Eigenvalues from the last fit(), or None before fit() has been called."""
        return self._eigval

    def fit(self, X: torch.Tensor) -> "Ncut":
        """
        Fit the Ncut model to the input features and save the state of the Nystrom approximation.

        Any k-way rotations cached by kway_fit() are discarded, because they were
        computed on the previous eigenvectors.

        Args:
            X (torch.Tensor): input features, shape (N, D)
        Returns:
            ncut (Ncut): Ncut instance
        """
        eigvec, eigval, indices, sigma = \
            ncut_fn(
                X,
                n_eig=self.n_eig,
                quantile_sigma=self.quantile_sigma,
                quantile_sigma_repulsion=self.quantile_sigma_repulsion,
                sigma=self.sigma,
                repulsion_sigma=self.repulsion_sigma,
                repulsion_weight=self.repulsion_weight,
                device=self.device,
                exact_gradient=self.exact_gradient,
                no_propagation=True,
                affinity_fn=self.affinity_fn,
                **self.kwargs
            )
        # store Ncut state to use in transform()
        self._nystrom_x = X[indices]
        self._nystrom_eigvec = eigvec
        self._eigval = eigval
        # NOTE: replaces the configured sigma with the fitted one, so later fit() calls reuse it
        self.sigma = sigma
        # cached k-way rotations belong to the old eigenvectors; drop them so kway_transform()
        # raises instead of silently applying a stale rotation
        self._kway_R.clear()
        return self

    def transform(self, X: torch.Tensor) -> torch.Tensor:
        """
        Transform new data using the fitted Ncut model and its saved Nystrom state.

        Args:
            X (torch.Tensor): input features, shape (N, D)

        Returns:
            eigvec (torch.Tensor): eigenvectors, shape (N, n_eig)

        Raises:
            ValueError: if fit() has not been called.
        """
        self._check_is_fitted()

        # propagate eigenvectors from subgraph to full graph
        eigvec = nystrom_propagate(
            self._nystrom_eigvec,
            X,
            self._nystrom_x,
            extrapolation_factor=self.extrapolation_factor,
            device=self.device,
            **self.kwargs
        )
        return eigvec

    def fit_transform(self, X: torch.Tensor) -> torch.Tensor:
        """
        Fit the model on X, then return the eigenvectors propagated to X.

        Args:
            X (torch.Tensor): input features, shape (N, D)

        Returns:
            eigvec (torch.Tensor): eigenvectors, shape (N, n_eig)
        """
        return self.fit(X).transform(X)

    def __call__(self, X: torch.Tensor) -> torch.Tensor:
        """Alias for fit_transform()."""
        return self.fit_transform(X)

    def _check_is_fitted(self) -> None:
        """Raise ValueError unless fit() has stored the Nystrom state."""
        if self._nystrom_x is None or self._nystrom_eigvec is None:
            raise ValueError("Ncut has not been fitted yet. Call fit() first.")

    def _validate_kway_params(self, n_clusters: int, n_eig: int) -> None:
        """Check fitted state and the (n_clusters, n_eig) arguments for k-way methods.

        Raises:
            ValueError: if not fitted, values are non-positive, n_eig < 2, or n_eig
                exceeds the number of fitted eigenvectors.
            TypeError: if n_clusters or n_eig is not an int.
        """
        self._check_is_fitted()

        if not isinstance(n_clusters, int) or not isinstance(n_eig, int):
            raise TypeError("n_clusters and n_eig must be integers.")
        if n_clusters <= 0 or n_eig <= 0:
            raise ValueError("n_clusters and n_eig must be positive.")
        if n_eig < 2:
            raise ValueError("n_eig must be at least 2 for k-way discretization.")
        if n_eig > self._nystrom_eigvec.shape[1]:
            raise ValueError(
                f"n_eig={n_eig} exceeds fitted eigenvector count {self._nystrom_eigvec.shape[1]}."
            )

    def kway_fit(self, n_clusters: int, n_eig: int, kmeans_iter: int = 10) -> "Ncut":
        """
        Fit and cache a k-way rotation matrix for the fitted eigenvectors.

        Args:
            n_clusters (int): number of output clusters.
            n_eig (int): number of leading eigenvectors to use.
            kmeans_iter (int): number of k-means refinement iterations.

        Returns:
            Ncut: current instance.
        """
        self._validate_kway_params(n_clusters=n_clusters, n_eig=n_eig)

        # rotation is estimated on the Nystrom-sampled eigenvectors only
        R = quick_kway(
            self._nystrom_eigvec[:, :n_eig],
            n_clusters=n_clusters,
            n_eig=n_eig,
            n_sample=self._nystrom_eigvec.shape[0],
            device=self.device,
            kmeans_iter=kmeans_iter,
            ret_R=True,
        )
        self._kway_R[(n_clusters, n_eig)] = R.cpu()
        return self

    def kway_transform(self, X: torch.Tensor, n_clusters: int, n_eig: int) -> torch.Tensor:
        """
        Transform data with a previously fitted k-way rotation matrix.

        Args:
            X (torch.Tensor): input features, shape (N, D).
            n_clusters (int): number of output clusters.
            n_eig (int): number of leading eigenvectors to use.

        Returns:
            torch.Tensor: rotated eigenvectors, shape (N, n_clusters).

        Raises:
            ValueError: if kway_fit() was not called for this (n_clusters, n_eig)
                since the last fit().
        """
        self._validate_kway_params(n_clusters=n_clusters, n_eig=n_eig)

        cache_key = (n_clusters, n_eig)
        if cache_key not in self._kway_R:
            raise ValueError(
                "K-way rotation has not been fitted for this configuration. "
                "Call kway_fit() with the same n_clusters and n_eig first."
            )

        eigvec = self.transform(X)[:, :n_eig]
        R = self._kway_R[cache_key]
        device = auto_device(self.device)

        return chunked_matmul(eigvec, R, device=device, large_device=eigvec.device)

__init__(n_eig=100, quantile_sigma=0.25, quantile_sigma_repulsion=0.2, sigma=None, repulsion_sigma=None, repulsion_weight=None, affinity_fn=rbf_affinity, extrapolation_factor=1.0, exact_gradient=False, device=None, **kwargs)

Parameters:

Name Type Description Default
n_eig int

number of eigenvectors

100
quantile_sigma float

quantile of affinity sigma parameter, lower quantile_sigma results in sharper eigenvectors

0.25
quantile_sigma_repulsion float

quantile of repulsion sigma parameter, lower quantile_sigma_repulsion results in sharper eigenvectors

0.2
sigma float

affinity parameter; if provided, overrides the sigma estimated from quantile_sigma

None
repulsion_sigma float

(if use repulsion) repulsion sigma parameter, default None (no repulsion)

None
repulsion_weight float

(if use repulsion) repulsion weight; repulsion is only applied when this is set (non-zero), default None (no repulsion)

None
affinity_fn callable

affinity function, default rbf_affinity. Should accept (X1, X2=None, sigma=float) and return affinity matrix

rbf_affinity
extrapolation_factor float

control how far can we extrapolate, larger extrapolation_factor means we can extrapolate further, default 1.0

1.0
exact_gradient bool

use full spectrum and exact gradient, can be slower and unstable, default False

False
device str | None

device for computation; None auto-detects the device (e.g. uses a GPU if available)

None

Examples:

>>> from ncut_pytorch import Ncut
>>> import torch
>>> X = torch.rand(10000, 100)
>>> ncut = Ncut(n_eig=20)
>>> eigvec = ncut.fit_transform(X)
>>> eigval = ncut.eigval
>>> print(eigvec.shape, eigval.shape)  # (10000, 20) (20,)
>>> 
>>> # transform new data
>>> new_X = torch.rand(500, 100)
>>> new_eigvec = ncut.transform(new_X)
>>> print(new_eigvec.shape)  # (500, 20)
Source code in ncut_pytorch/ncut.py
def __init__(
        self,
        n_eig: int = 100,
        quantile_sigma: float = 0.25,
        quantile_sigma_repulsion: float = 0.20,
        sigma: float | None = None,
        repulsion_sigma: float | None = None,
        repulsion_weight: float | None = None,
        affinity_fn: Union["rbf_affinity", "cosine_affinity"] = rbf_affinity,
        extrapolation_factor: float = 1.0,
        exact_gradient: bool = False,
        device: str | None = None,
        **kwargs,
):
    """
    Store the configuration; no computation happens until fit().

    Args:
        n_eig (int): number of eigenvectors
        quantile_sigma (float): quantile of affinity sigma parameter, lower quantile_sigma results in sharper eigenvectors
        quantile_sigma_repulsion (float): quantile of repulsion sigma parameter, lower quantile_sigma_repulsion results in sharper eigenvectors
        sigma (float | None): affinity parameter; if provided, overrides the sigma estimated from quantile_sigma
        repulsion_sigma (float | None): (if use repulsion) repulsion sigma parameter; if None it is presumably estimated from quantile_sigma_repulsion, default None
        repulsion_weight (float | None): (if use repulsion) repulsion weight; repulsion is only applied when this is set (non-zero), default None (no repulsion)
        affinity_fn (callable): affinity function, default rbf_affinity. Should accept (X1, X2=None, sigma=float) and return affinity matrix
        extrapolation_factor (float): control how far transform() can extrapolate, larger extrapolation_factor means we can extrapolate further, default 1.0
        exact_gradient (bool): use full spectrum and exact gradient, can be slower and unstable, default False
        device (str | None): device for computation, default None (auto-detected, e.g. GPU if available)
        **kwargs: extra options (e.g. Nystrom config overrides) forwarded to ncut_fn in fit() and to nystrom_propagate in transform()

    Examples:
        >>> from ncut_pytorch import Ncut
        >>> import torch
        >>> X = torch.rand(10000, 100)
        >>> ncut = Ncut(n_eig=20)
        >>> eigvec = ncut.fit_transform(X)
        >>> eigval = ncut.eigval
        >>> print(eigvec.shape, eigval.shape)  # (10000, 20) (20,)
        >>> 
        >>> # transform new data
        >>> new_X = torch.rand(500, 100)
        >>> new_eigvec = ncut.transform(new_X)
        >>> print(new_eigvec.shape)  # (500, 20)
    """
    self.n_eig = n_eig
    self.quantile_sigma = quantile_sigma
    self.quantile_sigma_repulsion = quantile_sigma_repulsion
    self.sigma = sigma
    self.repulsion_sigma = repulsion_sigma
    self.repulsion_weight = repulsion_weight
    self.extrapolation_factor = extrapolation_factor
    self.exact_gradient = exact_gradient
    self.device = device
    self.affinity_fn = affinity_fn
    self.kwargs = kwargs

    # fitted state, populated by fit()
    self._nystrom_x = None
    self._nystrom_eigvec = None
    self._eigval = None

    # k-way rotation matrices (stored on CPU), keyed by (n_clusters, n_eig); filled by kway_fit()
    self._kway_R: dict[tuple[int, int], torch.Tensor] = {}

fit(X)

Fit the Ncut model to the input features and save the state of the Nyström approximation.

Parameters:

Name Type Description Default
X Tensor

input features, shape (N, D)

required

Returns: ncut (Ncut): Ncut instance

Source code in ncut_pytorch/ncut.py
def fit(self, X: torch.Tensor) -> "Ncut":
    """
    Fit the Ncut model to the input features and save the state of the Nystrom approximation.

    Any k-way rotations cached by kway_fit() are discarded, because they were
    computed on the previous eigenvectors.

    Args:
        X (torch.Tensor): input features, shape (N, D)
    Returns:
        ncut (Ncut): Ncut instance
    """
    eigvec, eigval, indices, sigma = \
        ncut_fn(
            X,
            n_eig=self.n_eig,
            quantile_sigma=self.quantile_sigma,
            quantile_sigma_repulsion=self.quantile_sigma_repulsion,
            sigma=self.sigma,
            repulsion_sigma=self.repulsion_sigma,
            repulsion_weight=self.repulsion_weight,
            device=self.device,
            exact_gradient=self.exact_gradient,
            no_propagation=True,
            affinity_fn=self.affinity_fn,
            **self.kwargs
        )
    # store Ncut state to use in transform()
    self._nystrom_x = X[indices]
    self._nystrom_eigvec = eigvec
    self._eigval = eigval
    # NOTE: replaces the configured sigma with the fitted one, so later fit() calls reuse it
    self.sigma = sigma
    # cached k-way rotations belong to the old eigenvectors; drop them so kway_transform()
    # raises instead of silently applying a stale rotation
    self._kway_R.clear()
    return self

fit_transform(X)

Parameters:

Name Type Description Default
X Tensor

input features, shape (N, D)

required

Returns:

Name Type Description
eigvec Tensor

eigenvectors, shape (N, n_eig)

Source code in ncut_pytorch/ncut.py
def fit_transform(self, X: torch.Tensor) -> torch.Tensor:
    """
    Fit the model on X, then return the eigenvectors propagated to X.

    Args:
        X (torch.Tensor): input features, shape (N, D)

    Returns:
        eigvec (torch.Tensor): eigenvectors, shape (N, n_eig)
    """
    fitted = self.fit(X)
    return fitted.transform(X)

kway_fit(n_clusters, n_eig, kmeans_iter=10)

Fit and cache a k-way rotation matrix for the fitted eigenvectors.

Parameters:

Name Type Description Default
n_clusters int

number of output clusters.

required
n_eig int

number of leading eigenvectors to use.

required
kmeans_iter int

number of k-means refinement iterations.

10

Returns:

Name Type Description
Ncut Ncut

current instance.

Source code in ncut_pytorch/ncut.py
def kway_fit(self, n_clusters: int, n_eig: int, kmeans_iter: int = 10) -> "Ncut":
    """
    Compute the k-way rotation for the fitted eigenvectors and cache it.

    The rotation is estimated on the leading ``n_eig`` Nystrom eigenvectors and
    stored on the CPU under the key ``(n_clusters, n_eig)``.

    Args:
        n_clusters (int): number of output clusters.
        n_eig (int): number of leading eigenvectors to use.
        kmeans_iter (int): number of k-means refinement iterations.

    Returns:
        Ncut: ``self``, allowing calls to be chained.
    """
    self._validate_kway_params(n_clusters=n_clusters, n_eig=n_eig)

    leading = self._nystrom_eigvec[:, :n_eig]
    n_nodes = self._nystrom_eigvec.shape[0]
    rotation = quick_kway(
        leading,
        n_clusters=n_clusters,
        n_eig=n_eig,
        n_sample=n_nodes,
        device=self.device,
        kmeans_iter=kmeans_iter,
        ret_R=True,
    )

    key = (n_clusters, n_eig)
    self._kway_R[key] = rotation.cpu()
    return self

kway_transform(X, n_clusters, n_eig)

Transform data with a previously fitted k-way rotation matrix.

Parameters:

Name Type Description Default
X Tensor

input features, shape (N, D).

required
n_clusters int

number of output clusters.

required
n_eig int

number of leading eigenvectors to use.

required

Returns:

Type Description
Tensor

torch.Tensor: rotated eigenvectors, shape (N, n_clusters).

Source code in ncut_pytorch/ncut.py
def kway_transform(self, X: torch.Tensor, n_clusters: int, n_eig: int) -> torch.Tensor:
    """
    Project new data onto a k-way rotation cached by kway_fit().

    Args:
        X (torch.Tensor): input features, shape (N, D).
        n_clusters (int): number of output clusters.
        n_eig (int): number of leading eigenvectors to use.

    Returns:
        torch.Tensor: rotated eigenvectors, shape (N, n_clusters).

    Raises:
        ValueError: if no rotation is cached for ``(n_clusters, n_eig)``.
    """
    self._validate_kway_params(n_clusters=n_clusters, n_eig=n_eig)

    key = (n_clusters, n_eig)
    rotation = self._kway_R.get(key)
    if rotation is None:
        raise ValueError(
            "K-way rotation has not been fitted for this configuration. "
            "Call kway_fit() with the same n_clusters and n_eig first."
        )

    # propagate to X, keep the leading eigenvectors, then rotate in chunks
    leading = self.transform(X)[:, :n_eig]
    compute_device = auto_device(self.device)
    return chunked_matmul(leading, rotation, device=compute_device, large_device=leading.device)

transform(X)

Transform new data using the fitted Ncut model and its saved Nyström approximation state.

Parameters:

Name Type Description Default
X Tensor

input features, shape (N, D)

required

Returns:

Name Type Description
eigvec Tensor

eigenvectors, shape (N, n_eig)

Source code in ncut_pytorch/ncut.py
def transform(self, X: torch.Tensor) -> torch.Tensor:
    """
    Map new data into the fitted spectral embedding.

    Eigenvectors stored by fit() for the Nystrom-sampled nodes are propagated to
    every row of X via nystrom_propagate.

    Args:
        X (torch.Tensor): input features, shape (N, D)

    Returns:
        eigvec (torch.Tensor): eigenvectors, shape (N, n_eig)

    Raises:
        ValueError: if fit() has not been called.
    """
    self._check_is_fitted()
    return nystrom_propagate(
        self._nystrom_eigvec,
        X,
        self._nystrom_x,
        extrapolation_factor=self.extrapolation_factor,
        device=self.device,
        **self.kwargs,
    )

Core Ncut Functions

ncut_pytorch.ncuts.ncut_nystrom.ncut_fn(X, n_eig=100, quantile_sigma=0.25, quantile_sigma_repulsion=0.2, sigma=None, repulsion_sigma=None, repulsion_weight=None, affinity_fn=rbf_affinity, extrapolation_factor=1.0, exact_gradient=False, device=None, make_orthogonal=False, no_propagation=False, **kwargs)

Normalized Cut, balanced sampling and nystrom approximation.

Parameters:

Name Type Description Default
X Tensor

input features, shape (N, D)

required
n_eig int

number of eigenvectors

100
quantile_sigma float

quantile of affinity sigma parameter, lower quantile_sigma results in sharper eigenvectors

0.25
quantile_sigma_repulsion float

quantile of repulsion sigma parameter, lower quantile_sigma_repulsion results in sharper eigenvectors

0.2
sigma float

affinity parameter; if provided, overrides the sigma estimated from quantile_sigma

None
repulsion_sigma float

(if use repulsion) repulsion sigma parameter, default None (no repulsion)

None
repulsion_weight float

(if use repulsion) repulsion weight; repulsion is only applied when this is set (non-zero), default None (no repulsion)

None
affinity_fn callable

affinity function, default rbf_affinity. Should accept (X1, X2=None, sigma=float) and return affinity matrix

rbf_affinity
extrapolation_factor float

control how far can we extrapolate, larger extrapolation_factor means we can extrapolate further, default 1.0

1.0
exact_gradient bool

use full spectrum and exact gradient, can be slower and unstable, default False

False
make_orthogonal bool

make eigenvectors orthogonal

False

Returns:

Name Type Description
eigenvectors Tensor

shape (N, n_eig)

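eigenvalues Tensor

sorted in descending order, shape (n_eig,)

If no_propagation=True, the function instead returns (nystrom_eigvec, eigval, nystrom_indices, sigma), computed on the Nyström-sampled subgraph only.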

Examples:

>>> from ncut_pytorch import ncut_fn
>>> import torch
>>> features = torch.rand(10000, 100)
>>> eigvec, eigval = ncut_fn(features, n_eig=20)
>>> print(eigvec.shape, eigval.shape)  # (10000, 20) (20,)
Source code in ncut_pytorch/ncuts/ncut_nystrom.py
def ncut_fn(
        X: torch.Tensor,
        n_eig: int = 100,
        quantile_sigma: float = 0.25,
        quantile_sigma_repulsion: float = 0.20,
        sigma: float | None = None,
        repulsion_sigma: float | None = None,
        repulsion_weight: float | None = None,
        affinity_fn: Union["rbf_affinity", "cosine_affinity"] = rbf_affinity,
        extrapolation_factor: float = 1.0,
        exact_gradient: bool = False,
        device: str | None = None,
        make_orthogonal: bool = False,
        no_propagation: bool = False,
        **kwargs,
) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor, torch.Tensor, float]]:
    """Normalized Cut, balanced sampling and nystrom approximation.

    Args:
        X (torch.Tensor): input features, shape (N, D)
        n_eig (int): number of eigenvectors
        quantile_sigma (float): quantile of affinity sigma parameter, lower quantile_sigma results in sharper eigenvectors
        quantile_sigma_repulsion (float): quantile of repulsion sigma parameter, lower quantile_sigma_repulsion results in sharper eigenvectors
        sigma (float | None): affinity parameter; if provided, overrides the sigma estimated from quantile_sigma
        repulsion_sigma (float | None): (if use repulsion) repulsion sigma parameter; if None it is presumably estimated from quantile_sigma_repulsion
        repulsion_weight (float | None): (if use repulsion) repulsion weight; repulsion is only applied when this is set (non-zero), default None (no repulsion)
        affinity_fn (callable): affinity function, default rbf_affinity. Should accept (X1, X2=None, sigma=float) and return affinity matrix
        extrapolation_factor (float): control how far can we extrapolate, larger extrapolation_factor means we can extrapolate further, default 1.0
        exact_gradient (bool): use full spectrum and exact gradient, can be slower and unstable, default False
        device (str | None): device for computation, default None (resolved by auto_device from X.device)
        make_orthogonal (bool): make eigenvectors orthogonal (Gram-Schmidt after propagation)
        no_propagation (bool): if True, skip propagation and return the sampled-subgraph results (see Returns)
        **kwargs: NystromConfig overrides (e.g. n_sample, n_sample2, n_neighbors); also forwarded to nystrom_propagate

    Returns:
        If no_propagation is False:
            eigenvectors (torch.Tensor): shape (N, n_eig)
            eigenvalues (torch.Tensor): sorted in descending order, shape (n_eig,)
        If no_propagation is True, a 4-tuple:
            nystrom eigenvectors (torch.Tensor): for the sampled nodes only
            eigenvalues (torch.Tensor): shape (n_eig,)
            nystrom indices: indices into X of the sampled nodes
            sigma (float): affinity sigma actually used

    Examples:
        >>> from ncut_pytorch import ncut_fn
        >>> import torch
        >>> features = torch.rand(10000, 100)
        >>> eigvec, eigval = ncut_fn(features, n_eig=20)
        >>> print(eigvec.shape, eigval.shape)  # (10000, 20) (20,)
    """
    config = NystromConfig()
    config.update(kwargs)
    device = auto_device(X.device, device)

    # subsample for nystrom approximation; small inputs use every node
    n_sample = min(config.n_sample, int(X.shape[0] * config.n_sample_max_ratio))
    if X.shape[0] > SMALL_SCALE_THRESHOLD:
        nystrom_indices = farthest_point_sampling(X, n_sample=n_sample, device=device)
    else:
        nystrom_indices = torch.arange(X.shape[0])
    nystrom_X = X[nystrom_indices].to(device)

    sigma, repulsion_sigma = find_optimal_sigma(nystrom_X, quantile_sigma, quantile_sigma_repulsion, sigma, repulsion_sigma, affinity_fn)

    # repulsion is used only when both the (possibly estimated) repulsion sigma and its weight are truthy
    if repulsion_sigma and repulsion_weight:
        nystrom_eigvec, eigval = ncut_with_repulsion(nystrom_X, n_eig, sigma,
            repulsion_sigma, repulsion_weight, affinity_fn, exact_gradient)
    else:
        A = affinity_fn(nystrom_X, sigma=sigma)
        nystrom_eigvec, eigval = _plain_ncut(A, n_eig, exact_gradient)

    if no_propagation:
        return nystrom_eigvec, eigval, nystrom_indices, sigma

    # propagate eigenvectors from subgraph to full graph.
    # Forward the caller's config overrides unchanged: nystrom_propagate builds its own
    # NystromConfig and reads n_sample2 / n_neighbors / n_neighbors_max_ratio from it.
    # (Previously n_sample2 was passed under the key "n_sample", so it was never applied.)
    eigvec = nystrom_propagate(
        nystrom_eigvec,
        X,
        nystrom_X,
        extrapolation_factor=extrapolation_factor,
        device=device,
        **kwargs,
    )

    # post-hoc orthogonalization
    if make_orthogonal:
        eigvec = gram_schmidt(eigvec)

    return eigvec, eigval

ncut_pytorch.ncuts.ncut_nystrom.nystrom_propagate(nystrom_out, X, nystrom_X, extrapolation_factor=1.0, device=None, return_indices=False, **kwargs)

Propagate output from Nyström-sampled nodes to all nodes, using a weighted sum of each node's nearest sampled neighbors.

Parameters:

Name Type Description Default
nystrom_out Tensor

output from nystrom sampled nodes, shape (m, D)

required
X Tensor

input features for all nodes, shape (N, D)

required
nystrom_X Tensor

input features from nystrom sampled nodes, shape (m, D)

required
extrapolation_factor float

control how far can we extrapolate, larger extrapolation_factor means we can extrapolate further, default 1.0

1.0
device str

device to use for computation, if 'auto', will detect GPU automatically

None
return_indices bool

whether to return the indices used for propagation

False

Returns:

Type Description
Union[Tensor, tuple[Tensor, Tensor]]

torch.Tensor: output propagated by nearest neighbors, shape (N, D)

Source code in ncut_pytorch/ncuts/ncut_nystrom.py
def nystrom_propagate(
        nystrom_out: torch.Tensor,
        X: torch.Tensor,
        nystrom_X: torch.Tensor,
        extrapolation_factor: float = 1.0,
        device: str | None = None,
        return_indices: bool = False,
        **kwargs,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
    """Propagate output from Nystrom-sampled nodes to all nodes.

    Each row of X receives a weighted sum of the outputs of its nearest sampled
    neighbors, using RBF affinities divided by the summed degrees of those neighbors.

    Args:
        nystrom_out (torch.Tensor): output from nystrom sampled nodes, shape (m, C)
        X (torch.Tensor): input features for all nodes, shape (N, D)
        nystrom_X (torch.Tensor): input features from nystrom sampled nodes, shape (m, D)
        extrapolation_factor (float): multiplies the RBF sigma used for propagation; larger extrapolation_factor means we can extrapolate further, default 1.0
        device (str | None): device to use for computation, default None (resolved by auto_device)
        return_indices (bool): whether to also return the indices of the sampled rows used as propagation anchors
        **kwargs: NystromConfig overrides (e.g. n_sample2, n_neighbors, n_neighbors_max_ratio)

    Returns:
        torch.Tensor: output propagated by nearest neighbors, on X.device, shape (N, C);
        if return_indices, a tuple (output, indices).
    """
    if X.shape[0] <= SMALL_SCALE_THRESHOLD and nystrom_X.shape == X.shape and torch.allclose(nystrom_X.to(X.device), X, atol=1e-6):
        # skip propagation: for small graphs X *is* the sampled set.
        # Move to X.device so the output device matches the general path below.
        out = nystrom_out.to(X.device)
        if return_indices:
            return out, np.arange(X.shape[0])
        return out

    config = NystromConfig()
    config.update(kwargs)

    device = auto_device(nystrom_out.device, device)
    output_device = X.device
    # further subsample the anchors (FPS in output space) to bound the cost
    indices = farthest_point_sampling(nystrom_out, config.n_sample2, device=device)
    nystrom_out = nystrom_out[indices].to(device).contiguous()
    nystrom_X = nystrom_X[indices].to(device).contiguous()

    sigma = find_sigma_by_degree(nystrom_X, affinity_fn=rbf_affinity, quantile_sigma=0.25)
    sigma = sigma * extrapolation_factor

    # per-anchor degree (mean RBF affinity among anchors), used to normalize weights
    D = rbf_affinity(nystrom_X, sigma=sigma).mean(1)
    nystrom_x_sq = nystrom_X.pow(2).sum(dim=1).unsqueeze(0)

    n_neighbors = int(min(config.n_neighbors, len(indices) * config.n_neighbors_max_ratio))
    # at least 4 neighbors, but never more than the anchors available to the top-k search
    n_neighbors = min(max(n_neighbors, 4), nystrom_X.shape[0])
    n_chunk = _find_max_chunk_size(X, nystrom_X, device)
    offsets_cache: dict[int, torch.Tensor] = {}

    all_outs = torch.empty((X.shape[0], nystrom_out.shape[-1]), device=output_device, dtype=nystrom_out.dtype)
    for i in range(0, X.shape[0], n_chunk):
        end = min(i + n_chunk, X.shape[0])
        Xi = X[i:end].to(device)

        _Ai, _indices = _rbf_topk_from_squared_distance(Xi, nystrom_X, nystrom_x_sq, sigma, n_neighbors)

        _Di = D[_indices].sum(1)
        _Ai = _Ai / _Di[:, None]
        out = _weighted_neighbor_sum(_Ai, _indices, nystrom_out, offsets_cache)

        all_outs[i:end] = out.to(output_device)

    if return_indices:
        return all_outs, indices
    return all_outs

ncut_pytorch.ncuts.ncut_kway.kway_ncut(eigvec, n_clusters=None, device=None, ret_R=False, **kwargs)

K-way Ncut discretization.

Parameters:

Name Type Description Default
eigvec Tensor

Eigenvectors from Ncut output.

required
n_clusters int | None

Number of clusters to use. If None, will use the number of eigenvectors.

None
device str | None

Device to use for computation.

None
ret_R bool

Whether to return the rotation matrix.

False

Returns:

Type Description
Tensor

Discretized eigenvectors (rotation matrix if ret_R=True).

Source code in ncut_pytorch/ncuts/ncut_kway.py
def kway_ncut(
    eigvec: torch.Tensor,          # [n, k]
    n_clusters: int | None = None,
    device: str | None = None,
    ret_R: bool = False,
    **kwargs,
) -> torch.Tensor:  # [n, n_clusters], or the [n_clusters, n_clusters] rotation if ret_R
    """
    K-way Ncut discretization.

    Args:
        eigvec: Eigenvectors from Ncut output, shape (n, k).
        n_clusters: Number of clusters to use; the leading n_clusters eigenvectors are
            used. If None (or 0), uses all eigvec.shape[1] eigenvectors.
        device: Device to use for computation.
        ret_R: If True, return only the rotation matrix R instead of the rotated eigenvectors.
        **kwargs: Forwarded to axis_align (max_iter, n_sample, sample_idx).

    Returns:
        The rotated eigenvectors eigvec[:, :n_clusters] @ R, shape (n, n_clusters),
        or R itself, shape (n_clusters, n_clusters), when ret_R=True.
    """
    n_clusters = n_clusters or eigvec.shape[1]
    leading = eigvec[:, :n_clusters]
    R = axis_align(leading, device=device, **kwargs)
    if ret_R:
        return R
    compute_device = auto_device(eigvec.device, device)
    # chunked to bound memory on the compute device; the result lives on eigvec.device
    return chunked_matmul(leading, R, device=compute_device, large_device=eigvec.device)

ncut_pytorch.ncuts.ncut_kway.axis_align(eigvec, device=None, max_iter=1000, n_sample=10240, sample_idx=None)

Multiclass Spectral Clustering (SX Yu, J Shi, 2003).

Source code in ncut_pytorch/ncuts/ncut_kway.py
@torch.no_grad()
def axis_align(
    eigvec: torch.Tensor,          # [n, k]
    device: str | None = None,
    max_iter: int = 1000,
    n_sample: int = 10240,
    sample_idx: torch.Tensor | None = None,
) -> torch.Tensor:                 # [k, k]
    """Multiclass Spectral Clustering (SX Yu, J Shi, 2003).

    Searches for a rotation R such that the row-normalized eigenvectors, projected
    by R, are close to a one-hot (discrete) assignment. Alternates between
    discretizing eigvec @ R and re-solving for R from an SVD of disc.T @ eigvec.

    Args:
        eigvec: continuous eigenvectors, shape (n, k).
        device: device for the iterative solve (resolved with auto_device).
        max_iter: iteration cap; the loop exits once iter_count exceeds it, so up to
            max_iter + 1 iterations run.
        n_sample: number of rows chosen by farthest point sampling when sample_idx
            is not given.
        sample_idx: optional precomputed row indices to use instead of FPS.

    Returns:
        R of shape (k, k), on eigvec's device and dtype, with its columns reordered
        by ascending values of R's second row.
    """
    n, k = eigvec.shape
    # iterate on a subsample of rows to keep each step cheap
    if sample_idx is None:
        sample_idx = farthest_point_sampling(eigvec, n_sample, device=device)
    eigvec = eigvec[sample_idx]

    # scale each row to unit length
    eigvec = F.normalize(eigvec, dim=1)

    # Initialize R matrix with FPS
    # (k well-spread rows, transposed -> [k, k])
    _sample_idx = farthest_point_sampling(eigvec, k, device=device)
    R = eigvec[_sample_idx].T

    # solve in float32 on the compute device; cast back at the end
    original_device = eigvec.device
    original_dtype = eigvec.dtype
    device = auto_device(original_device, device)
    eigvec = eigvec.to(device=device, dtype=torch.float32)
    R = R.to(device=device, dtype=torch.float32)

    last_obj = 0.0
    exit_loop = False
    iter_count = 0

    while not exit_loop:
        iter_count += 1

        # Discretize projected eigenvectors
        _eig_cont = eigvec @ R
        _eig_disc = _onehot_discretize(_eig_cont)
        _eig_disc = _eig_disc.to(device=device, dtype=eigvec.dtype)

        # SVD decomposition
        # Autocast is disabled for the SVD and half precision is upcast. If that raises
        # a RuntimeError, the SVD is retried outside the autocast context.
        _out = _eig_disc.T @ eigvec
        _out_dtype = _out.dtype
        try:
            with torch.autocast(device_type=_out.device.type, enabled=False):
                if _out_dtype in (torch.float16, torch.bfloat16):
                    _out = _out.float()
                U, S, Vh = torch.linalg.svd(_out, full_matrices=False)
        except RuntimeError:
            if _out_dtype in (torch.float16, torch.bfloat16):
                _out = _out.float()
            U, S, Vh = torch.linalg.svd(_out, full_matrices=False)
        U, S, Vh = U.to(_out_dtype), S.to(_out_dtype), Vh.to(_out_dtype)
        V = Vh.T

        # NOTE(review): n is the row count *before* subsampling, so this objective is offset
        # by a constant; only its change between iterations matters for convergence.
        ncut_val = 2 * (n - torch.sum(S))

        # Check convergence
        if torch.abs(ncut_val - last_obj) < torch.finfo(torch.float32).eps or iter_count > max_iter:
            exit_loop = True
        else:
            last_obj = ncut_val
            # orthogonal R closest to the current discrete/continuous correlation
            R = V @ U.T

    R = R.to(device=original_device, dtype=original_dtype)
    # reorder columns by the values of R's second row
    # NOTE(review): presumably for a deterministic cluster order; requires k >= 2
    R = R[:, torch.argsort(R[1])]
    return R

ncut_pytorch.ncuts.ncut_kway.quick_kway(eigvec, n_clusters=10, n_eig=10, n_sample=10240, device=None, kmeans_iter=10, ret_R=False)

Quick K-way Ncut using K-means for rotation matrix.

Source code in ncut_pytorch/ncuts/ncut_kway.py
def quick_kway(
    eigvec: torch.Tensor,          # [n, k]
    n_clusters: int = 10,
    n_eig: int = 10,
    n_sample: int = 10240,
    device: str | None = None,
    kmeans_iter: int = 10,
    ret_R: bool = False,
) -> torch.Tensor:  # [n, n_clusters], or the rotation R if ret_R
    """Quick K-way Ncut using K-means for rotation matrix.

    Args:
        eigvec: eigenvectors, shape (n, k).
        n_clusters: number of output clusters.
        n_eig: number of leading eigenvectors used.
        n_sample: number of rows used to estimate the rotation.
        device: device for computation.
        kmeans_iter: number of k-means refinement iterations.
        ret_R: if True, return only the rotation matrix R (presumably shape
            (n_eig, n_clusters)) instead of the rotated eigenvectors.

    Returns:
        eigvec[:, :n_eig] @ R, or R alone when ret_R=True.
    """
    R = _kmeans_kway(eigvec, n_clusters, n_eig, n_sample, device, kmeans_iter)
    if ret_R:
        return R
    compute_device = auto_device(eigvec.device, device)
    # chunked to bound memory on the compute device; the result lives on eigvec.device
    return chunked_matmul(eigvec[:, :n_eig], R, device=compute_device, large_device=eigvec.device)

ncut_pytorch.ncuts.ncut_click.ncut_click_prompt(X, fg_indices, bg_indices=None, click_weight=0.5, bg_weight=0.1, n_eig=2, quantile_sigma=0.25, device=None, sigma=None, affinity_fn=rbf_affinity, exact_gradient=False, no_propagation=False, return_indices_and_sigma=False, **kwargs)

Source code in ncut_pytorch/ncuts/ncut_click.py
def ncut_click_prompt(
        X: torch.Tensor,
        fg_indices: np.ndarray,
        bg_indices: np.ndarray = None,
        click_weight: float = 0.5,
        bg_weight: float = 0.1,
        n_eig: int = 2,
        quantile_sigma: float = 0.25,
        device: str = None,
        sigma: float = None,
        affinity_fn: Callable[[torch.Tensor, torch.Tensor, float], torch.Tensor] = rbf_affinity,
        exact_gradient: bool = False,
        no_propagation: bool = False,
        return_indices_and_sigma: bool = False,
        **kwargs,
) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor, torch.Tensor, float]]:
    """Normalized Cut on a Nystrom subgraph whose affinity is steered by clicks.

    The clicked points are placed first in the Nystrom subgraph. A click
    affinity is built from the mean affinity row of the foreground clicks,
    minus ``bg_weight`` times that of the background clicks. It is mixed with
    the feature affinity as ``click_weight * A_click + (1 - click_weight) * A``.

    Args:
        X: features, indexed by row (``X[indices]``); presumably
            (n_points, n_features) -- TODO confirm.
        fg_indices: row indices into ``X`` of foreground clicks (at least one).
        bg_indices: row indices into ``X`` of background clicks, optional.
        click_weight: mixing weight of the click affinity.
        bg_weight: scale applied to the mean background affinity row.
        n_eig: number of eigenvectors to compute.
        quantile_sigma: passed to ``find_sigma_by_degree`` when ``sigma`` is
            None and ``affinity_fn`` is ``rbf_affinity``.
        device: compute device, resolved with ``auto_device``.
        sigma: affinity sigma. If None, it is derived for ``rbf_affinity`` or
            set to 0.5 for ``cosine_affinity``.
        affinity_fn: affinity function called as ``affinity_fn(X, sigma=...)``.
        exact_gradient: forwarded to ``_plain_ncut``.
        no_propagation: if True, return the Nystrom-subgraph eigenvectors
            without propagating them to all of ``X``.
        return_indices_and_sigma: if True, also return the propagation
            indices (as rows of ``X``) and the sigma used.
        **kwargs: used to update ``NystromConfig``. This function reads
            ``n_sample``, ``n_sample2`` and ``n_neighbors`` from it.

    Returns:
        ``(eigvec, eigval)`` by default.
        ``(nystrom_eigvec, eigval, nystrom_indices, sigma)`` if ``no_propagation``.
        ``(eigvec, eigval, indices, sigma)`` if ``return_indices_and_sigma``.

    Raises:
        ValueError: if ``fg_indices`` is empty.
    """
    config = NystromConfig()
    config.update(kwargs)

    # use GPU if available
    device = auto_device(X.device, device)

    # accept lists / tensors / scalars for the click indices
    fg_indices = np.asarray(fg_indices, dtype=np.int64).reshape(-1)
    if bg_indices is None:
        bg_indices = np.array([], dtype=np.int64)
    bg_indices = np.asarray(bg_indices, dtype=np.int64).reshape(-1)
    if len(fg_indices) == 0:
        # the mean over zero foreground rows would be NaN and poison the whole cut
        raise ValueError("ncut_click_prompt requires at least one foreground click (fg_indices is empty)")

    # subsample for nystrom approximation; move to host numpy so np.isin works
    fps_indices = farthest_point_sampling(X, n_sample=config.n_sample, device=device)
    fps_indices = torch.as_tensor(fps_indices, dtype=torch.long).cpu().numpy()
    click_indices = np.concatenate([fg_indices, bg_indices])
    # drop clicked points from the samples, then put the clicks first
    fps_indices = fps_indices[~np.isin(fps_indices, click_indices)]
    nystrom_indices = np.concatenate([click_indices, fps_indices])
    # positions of the clicks inside the reordered subgraph
    local_fg = np.arange(len(fg_indices))
    local_bg = np.arange(len(bg_indices)) + len(fg_indices)

    nystrom_X = X[nystrom_indices].to(device)

    # find optimal sigma for affinity matrix
    if sigma is None and affinity_fn == rbf_affinity:
        sigma = find_sigma_by_degree(nystrom_X, quantile_sigma, affinity_fn)
        # TODO: change to std()
    elif sigma is None and affinity_fn == cosine_affinity:
        sigma = 0.5

    # compute Ncut on the nystrom sampled subgraph
    A = affinity_fn(nystrom_X, sigma=sigma)
    A = normalize_affinity(A)

    # modify the affinity from the clicks
    X_click = A[local_fg].mean(0)
    if len(local_bg) > 0:
        X_click = X_click - bg_weight * A[local_bg].mean(0)

    X_click = X_click * A.shape[0]

    A_click = affinity_fn(X_click.unsqueeze(1), sigma=0.5)
    A_click = normalize_affinity(A_click)

    _A = click_weight * A_click + (1 - click_weight) * A

    nystrom_eigvec, eigval = _plain_ncut(_A, n_eig, exact_gradient=exact_gradient)

    if no_propagation:
        return nystrom_eigvec, eigval, nystrom_indices, sigma

    # propagate eigenvectors from subgraph to full graph
    eigvec, nystrom_indices2 = nystrom_propagate(
        nystrom_eigvec,
        X,
        nystrom_X,
        n_neighbors=config.n_neighbors,
        n_sample=config.n_sample2,
        device=device,
        return_indices=True,
    )

    if return_indices_and_sigma:
        # nystrom_indices2 indexes the subgraph; map it back to rows of X
        indices = nystrom_indices[nystrom_indices2]
        return eigvec, eigval, indices, sigma

    return eigvec, eigval

Predictors

ncut_pytorch.predictor.predictor.NcutPredictor

Source code in ncut_pytorch/predictor/predictor.py
class NcutPredictor:
    """Stateful Ncut helper behind the interactive (vision) predictors.

    Call ``initialize(features)`` first. After that, ``get_n_segments``,
    ``get_hierarchy_masks``, ``predict_clicks`` and ``inference_new_features``
    can be used. ``features`` is indexed by row throughout
    (``features[indices]``). It is presumably a flattened
    (n_points, n_features) tensor -- TODO confirm.
    """
    _initialized: bool = False
    device: str = 'cpu'
    color_method: str = 'mspace'
    # ncut_fn: Callable = partial(ncut_fn, affinity_fn=cosine_affinity, sigma=0.4, repulsion_sigma=0.3)
    # staticmethod keeps ``self.ncut_fn(...)`` from binding ``self``. Recent
    # Python versions treat functools.partial as a method descriptor.
    ncut_fn: Callable = staticmethod(partial(ncut_fn))

    def __init__(self):
        # Declarations only; values are assigned by initialize()/predict_clicks().
        self._features: torch.Tensor
        self._hierarchy_assign: List[torch.Tensor]
        self._eigvecs: torch.Tensor
        self._color_palette: torch.Tensor

        # inference states
        self._nystrom_indices: torch.Tensor
        self._sigma: float
        self._click_eigvecs: torch.Tensor
        self._R: torch.Tensor
        self._fg_idx: int
        self._bg_idx: int

        # kway ncut states
        self._kway_sample_idx: torch.Tensor

    def initialize(self,
                   features: torch.Tensor,
                   n_segments: Union[List[int], int] = (5, 25, 50, 100, 250)
                   ) -> None:
        """Set the features, compute eigenvectors and cache segmentations.

        Args:
            features: features to segment, indexed by row.
            n_segments: segment count(s) whose cluster assignments are cached
                for ``get_hierarchy_masks``.
        """
        self._features = features
        if isinstance(n_segments, int):
            n_segments = [n_segments]
        # click results from earlier features would index the wrong rows
        self._reset_click_state()
        self.refresh_eigvecs(max(n_segments))
        self._initialized = True
        self.cache_hierarchy(n_segments)
        self._color_palette = []

    def _reset_click_state(self) -> None:
        """Forget click-prompt state computed by predict_clicks()."""
        for name in ('_nystrom_indices', '_sigma', '_click_eigvecs', '_R', '_fg_idx', '_bg_idx'):
            self.__dict__.pop(name, None)

    def refresh_eigvecs(self, n_eig: int) -> None:
        """Recompute ``n_eig`` eigenvectors with ``ncut_fn`` and re-sample the
        k-way sample indices."""
        eigvecs, eigval = self.ncut_fn(self._features, n_eig=n_eig, device=self.device)
        self._eigvecs = eigvecs
        self._kway_sample_idx = farthest_point_sampling(eigvecs, 10240, device=self.device)

    def get_n_eigvecs(self, n_eig: int) -> torch.Tensor:
        """Return the first ``n_eig`` eigenvectors, recomputing if the cache is too small."""
        cache_hit = n_eig <= self._eigvecs.shape[1]
        if not cache_hit:
            self.refresh_eigvecs(n_eig)
        return self._eigvecs[:, :n_eig]

    def cache_hierarchy(self, n_segments: List[int]) -> None:
        """Compute and store one cluster assignment per entry of ``n_segments``."""
        hierarchy_assign = []
        for n_eig in n_segments:
            hierarchy_assign.append(self.get_n_segments(n_eig))
        self._hierarchy_assign = hierarchy_assign

    def get_n_segments(self, n_cluster: int) -> torch.Tensor:
        """Return one cluster label per row (on CPU) from a ``n_cluster``-way Ncut."""
        self.__check_initialized()
        eigvecs = self.get_n_eigvecs(n_cluster)
        # kway_eigvec = kway_ncut(eigvecs, device=self.device, sample_idx=self._kway_sample_idx)
        kway_eigvec = quick_kway(eigvecs, n_eig=n_cluster, n_clusters=n_cluster, device=self.device)
        cluster_assignment = kway_eigvec.argmax(dim=1).cpu()
        return cluster_assignment

    def get_hierarchy_masks(self, point_index: int) -> List[torch.Tensor]:
        """For every cached hierarchy level, return a boolean mask of the rows
        in the same cluster as ``point_index``."""
        self.__check_initialized()
        masks: List[torch.Tensor] = []
        for cluster_assignment in self._hierarchy_assign:
            cluster_idx = cluster_assignment[point_index].item()
            mask = cluster_assignment == cluster_idx
            masks.append(mask)
        return masks

    def predict_clicks(self,
                       fg_indices: np.ndarray,
                       bg_indices: np.ndarray,
                       click_weight: float,
                       **kwargs
                       ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Segment foreground vs. background from clicked row indices.

        Args:
            fg_indices: row indices of foreground clicks.
            bg_indices: row indices of background clicks.
            click_weight: forwarded to ``ncut_click_prompt``.
            **kwargs: forwarded to ``ncut_click_prompt``. ``device`` defaults
                to this predictor's device.

        Returns:
            ``(mask, heatmap)``: boolean foreground mask per row, and the
            foreground-minus-background score per row.
        """
        self.__check_initialized()
        # use the predictor's device like every other Ncut call in this class
        kwargs.setdefault('device', self.device)
        eigvecs, eigval, nystrom_indices, sigma = ncut_click_prompt(
            self._features,
            fg_indices,
            bg_indices,
            return_indices_and_sigma=True,
            click_weight=click_weight,
            **kwargs,
        )

        eigvecs = kway_ncut(eigvecs, device=self.device)
        R = axis_align(eigvecs, device=self.device)
        kway_eigvecs = chunked_matmul(eigvecs, R, device=self.device, large_device=eigvecs.device)

        # find which cluster is the foreground
        fg_eigvecs = kway_eigvecs[fg_indices]
        fg_idx = fg_eigvecs.mean(0).argmax().item()
        bg_idx = 1 if fg_idx == 0 else 0

        # discretize the eigvecs
        mask = kway_eigvecs.argmax(dim=-1) == fg_idx
        heatmap = kway_eigvecs[:, fg_idx] - kway_eigvecs[:, bg_idx]

        # save for inference use
        self._nystrom_indices = nystrom_indices
        self._sigma = sigma
        self._click_eigvecs = eigvecs
        self._R = R
        self._fg_idx = fg_idx
        self._bg_idx = bg_idx

        return mask, heatmap

    def inference_new_features(self, new_features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the last ``predict_clicks`` result to new features.

        Raises:
            NotInitializedError: if ``predict_clicks`` has not been called for
                the current features.
        """
        self.__check_initialized()
        if not hasattr(self, "_nystrom_indices") or len(self._nystrom_indices) == 0:
            raise NotInitializedError("please call predict_clicks() before inference_new_features()")

        nystrom_X = self._features[self._nystrom_indices]
        nystrom_out = self._click_eigvecs[self._nystrom_indices]
        eigvecs = nystrom_propagate(nystrom_out, new_features, nystrom_X, device=self.device)
        eigvecs = chunked_matmul(eigvecs, self._R, device=self.device, large_device=eigvecs.device)
        mask = eigvecs.argmax(dim=-1) == self._fg_idx
        heatmap = eigvecs[:, self._fg_idx] - eigvecs[:, self._bg_idx]
        return mask, heatmap

    def get_color_palette(self, n_eig: int = 50) -> torch.Tensor:
        """Return the per-row color palette, computing it on first use."""
        cache_hit = hasattr(self, '_color_palette') and len(self._color_palette) > 0
        if not cache_hit:
            self.refresh_color_palette(n_eig)
        return self._color_palette

    def refresh_color_palette(self, n_eig: int = 50) -> None:
        """Recompute the palette with ``color_method``. mspace falls back to
        umap (with a warning) on error or NaN output.

        Raises:
            ValueError: if ``color_method`` is not 'mspace', 'tsne' or 'umap'.
        """
        self.__check_initialized()
        if self.color_method == 'mspace':
            try:
                palette = mspace_color(self._features)
                # explicit check: an ``assert`` would be stripped under ``python -O``
                if torch.isnan(palette).any():
                    raise ValueError("mspace_color returned NaN values")
                self._color_palette = palette
            except Exception as e:
                warnings.warn(f"Error in mspace_color: {e}, using umap instead")
                self._color_palette = umap_color(self._eigvecs[:, :n_eig])
        elif self.color_method == 'tsne':
            self._color_palette = tsne_color(self._eigvecs[:, :n_eig])
        elif self.color_method == 'umap':
            self._color_palette = umap_color(self._eigvecs[:, :n_eig])
        else:
            raise ValueError(f"Invalid color method: {self.color_method}")

    def inference_new_color_palette(self, new_image: torch.Tensor) -> torch.Tensor:
        """Not implemented yet; raises instead of silently returning None."""
        raise NotImplementedError("inference_new_color_palette() is not implemented yet")

    def __check_initialized(self) -> None:
        if not self._initialized or not hasattr(self, '_features') or \
            not hasattr(self, '_eigvecs'):
            raise NotInitializedError("Not initialized, please call initialize() first")

    def to(self, device: Union[str, torch.device]):
        """Set the device used by later calls; existing tensors are not moved."""
        self.device = device
        return self

ncut_pytorch.predictor.vision_predictor.NcutVisionPredictor

Source code in ncut_pytorch/predictor/vision_predictor.py
class NcutVisionPredictor:
    """Ncut segmentation and click prompting on top of a vision backbone.

    ``model`` maps a batch of transformed images to a (b, c, h, w) feature
    map. The features of all images are flattened row-major into
    (b * h * w, c) and handed to an ``NcutPredictor``.
    """
    _initialized: bool = False

    def __init__(self,
                 model: nn.Module,
                 transform: transforms.Compose,
                 batch_size: int):
        """
        Args:
            model (nn.Module): feature extractor; its output is read as (b, c, h, w).
            transform (transforms.Compose): per-image preprocessing returning a tensor.
            batch_size (int): number of images per forward pass.
        """
        self.model = model
        self.transform = transform

        self.batch_size = batch_size

        self._images: List[Image.Image]
        self._image_whs: List[Tuple[int, int]]
        self._feat_hws: Tuple[int, int]

        self.predictor = NcutPredictor()

    def set_images(self,
                   images: List[Image.Image],
                   n_segments: List[int] = (5, 25, 50, 100, 250)):
        """
        Set the images and cache their features.

        Args:
            images (List[Image.Image]): List of images to set.
            n_segments (List[int], optional): Number of segments to cache.
                These are the levels shown by the preview function.
        """
        features = self.forward_model(images)  # (b, c, h, w)
        self._images = images
        self._image_whs = np.array([image.size for image in images])  # PIL size is (w, h)
        self._feat_hws = (features.shape[2], features.shape[3])

        flat_features = features.permute(0, 2, 3, 1).reshape(-1, features.shape[1])
        self.predictor.initialize(flat_features, n_segments)
        self._initialized = True

    @torch.inference_mode()
    def forward_model(self, images: List[Image.Image]) -> torch.Tensor:
        """Run the model in batches and return the concatenated features on CPU."""
        device = next(self.model.parameters()).device
        all_features = []
        for i in range(0, len(images), self.batch_size):
            batch_images = images[i:i + self.batch_size]
            transformed_images = torch.stack([self.transform(image) for image in batch_images])
            transformed_images = transformed_images.to(device)
            try:
                with torch.autocast(device_type=device.type, enabled=True):
                    features = self.model(transformed_images)
            except RuntimeError:  # old torch version
                features = self.model(transformed_images)
            features = features.to('cpu')
            all_features.append(features)
        return torch.cat(all_features, dim=0)

    def generate(self, n_segment: int, n_eig: int = 10) -> torch.Tensor:
        """
        Generate the cluster assignment for the images.

        Args:
            n_segment (int): Number of segments (clusters) to generate.
            n_eig (int, optional): Currently unused; kept for API compatibility.

        Returns:
            torch.Tensor: Cluster assignment for the images. (b, h, w)
        """
        self.__check_initialized()
        cluster_assignment = self.predictor.get_n_segments(n_segment)
        b, h, w = len(self._images), self._feat_hws[0], self._feat_hws[1]
        cluster_assignment = cluster_assignment.reshape(b, h, w)
        return cluster_assignment

    def preview(self,
                point_coord: Tuple[int, int],
                image_indices: int) -> List[torch.Tensor]:
        """
        Preview the hierarchical cluster assignment at a point.

        Args:
            point_coord (Tuple[int, int]): The (x, y) coordinate of the point, in original image resolution.
            image_indices (int): The index of the image that point_coord belongs to.

        Returns:
            List[torch.Tensor]: List of masks for each hierarchy level. each mask is (b, h, w)
        """
        self.__check_initialized()
        b, h, w = len(self._images), self._feat_hws[0], self._feat_hws[1]

        point_index = self._image_xy_to_tensor_index(self._image_whs, 
                                                     self._feat_hws,
                                                     np.array([point_coord]), 
                                                     np.array([image_indices])
                                                     )[0]
        masks = self.predictor.get_hierarchy_masks(point_index)
        masks = [mask.reshape(b, h, w) for mask in masks]
        return masks

    def summary(self,
                n_segments: List[int] = (5, 25, 50, 100, 250),
                n_eig: int = 10,
                draw_border: bool = True,
                ) -> Image.Image:
        """
        Summarize the cluster assignments as a grid image.

        Each row is one image. The columns are the original image, one
        discrete coloring per entry of n_segments, and the continuous coloring.
        Every cell is resized to 512x512.

        Args:
            n_segments (List[int]): Number of segments to summarize.
            n_eig (int): Forwarded to generate() (currently unused there).
            draw_border (bool): Whether to draw segment borders.

        Returns:
            Image.Image: The grid image.
        """
        self.__check_initialized()
        display_hw = 512

        colors = []
        colors.append(self._images)
        for n_segment in n_segments:
            cluster_assignment = self.generate(n_segment, n_eig=n_eig)
            color = self.color_discrete(cluster_assignment, draw_border=draw_border)
            colors.append(color)
        color = self.color_continues()
        colors.append(color)

        # make a grid of images
        n_rows = len(self._images)
        n_cols = len(n_segments) + 2
        grid_image = Image.new('RGB', size=(n_cols * display_hw, n_rows * display_hw))
        for i in range(n_rows):
            for j in range(n_cols):
                img = colors[j][i]
                img = Image.fromarray(np.array(img))
                img = img.resize((display_hw, display_hw), Image.Resampling.NEAREST)
                grid_image.paste(img, box=(j * display_hw, i * display_hw))
        return grid_image

    def predict(self,
                point_coords: np.ndarray,
                point_labels: np.ndarray,
                image_indices: np.ndarray,
                click_weight: float = 0.5,
                **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predict the mask and heatmap for the images, based on the clicks.

        Args:
            point_coords (np.ndarray): The (x, y) coordinates of the points, in original image resolution. (n, 2)
            point_labels (np.ndarray): The labels of the points, 1 (positive) or 0 (negative). (n, )
            image_indices (np.ndarray): The indices of the images corresponding to point_coords. (n, )
            click_weight (float, optional): The weight of the click. Defaults to 0.5.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Mask and heatmap for the images. (b, h, w)
        """
        self.__check_initialized()
        # accept plain lists: boolean-mask indexing below needs ndarrays
        point_coords = np.asarray(point_coords)
        point_labels = np.asarray(point_labels)
        image_indices = np.asarray(image_indices)
        fg_indices = self._image_xy_to_tensor_index(self._image_whs, self._feat_hws,
                                                    point_coords[point_labels == 1],
                                                    image_indices[point_labels == 1])
        bg_indices = self._image_xy_to_tensor_index(self._image_whs, self._feat_hws,
                                                    point_coords[point_labels == 0],
                                                    image_indices[point_labels == 0])
        b, h, w = len(self._images), self._feat_hws[0], self._feat_hws[1]

        mask, heatmap = self.predictor.predict_clicks(fg_indices, bg_indices, click_weight, **kwargs)

        mask = mask.reshape(b, h, w)
        heatmap = heatmap.reshape(b, h, w)

        return mask, heatmap

    def inference(self, images: List[Image.Image]) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Infer the mask and heatmap for new images, using the state saved by predict.

        Args:
            images (List[Image.Image]): List of images to run inference on.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Mask and heatmap for the images. (b, h, w)
        """
        self.__check_initialized()
        new_features = self.forward_model(images)
        b, c, h, w = new_features.shape
        new_features = new_features.permute(0, 2, 3, 1).reshape(-1, c)

        mask, heatmap = self.predictor.inference_new_features(new_features)

        mask = mask.reshape(b, h, w)
        heatmap = heatmap.reshape(b, h, w)
        return mask, heatmap

    def color_continues(self) -> np.ndarray:
        """
        Color the features with the continuous mspace color palette.

        Values are clamped to [0, 255] before the uint8 cast, so palette
        values outside [0, 1] saturate instead of overflowing.

        Returns:
            np.ndarray: RGB image. (b, h, w, 3)
        """
        self.__check_initialized()
        b, h, w = len(self._images), self._feat_hws[0], self._feat_hws[1]
        color_palette = self.predictor.get_color_palette()
        rgb = color_palette.reshape(b, h, w, 3)
        rgb = (rgb * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
        return rgb

    def color_discrete(self, 
                       cluster_assignment: torch.Tensor,
                       draw_border: bool = True,
                       ) -> List[Image.Image]:
        """
        Color each cluster with the mean mspace palette color of its members.

        Args:
            cluster_assignment (torch.Tensor): Cluster assignment for the images. (b, h, w)
            draw_border (bool, optional): Whether to draw segment borders. Defaults to True.

        Returns:
            List[Image.Image]: List of RGB images, resized to the original image sizes.
        """
        self.__check_initialized()
        b, h, w = cluster_assignment.shape
        lowres_cluster_assignment = cluster_assignment.cpu().numpy()
        color_palette = self.predictor.get_color_palette()
        n_cluster = int(cluster_assignment.max().item()) + 1
        flat_cluster_assignment = cluster_assignment.flatten()
        discrete_rgb = np.zeros((b * h * w, 3), dtype=np.uint8)
        for i in range(n_cluster):
            mask = flat_cluster_assignment == i
            if not mask.any():
                continue
            color = color_palette[mask].mean(0)
            color = (color * 255).cpu().numpy()
            color = np.clip(color, 0, 255).astype(np.uint8)
            discrete_rgb[mask.cpu().numpy()] = color
        discrete_rgb = discrete_rgb.reshape(b, h, w, 3)

        # convert to PIL image and resize to the original size
        pil_images = []
        for i in range(b):
            img = Image.fromarray(discrete_rgb[i])
            img = img.resize(self._images[i].size, Image.Resampling.NEAREST)
            if draw_border:
                img = self._draw_segments_border(img, lowres_cluster_assignment[i])
            pil_images.append(img)

        return pil_images

    def refresh_color_palette(self):
        """Recompute the predictor's color palette."""
        self.predictor.refresh_color_palette()

    @staticmethod
    def _draw_segments_border(
        img: Image.Image,
        lowres_cluster_assignment: np.ndarray,
        min_area_ratio: float = 0.0005,
    ) -> Image.Image:
        """Paint label boundaries (and the image frame) black on ``img``.

        Pixels of connected components smaller than
        ``min_area_ratio * lowres_cluster_assignment.size`` get no border.
        """
        drawing = np.array(img)
        labels = NcutVisionPredictor._resize_label_map(
            lowres_cluster_assignment,
            drawing.shape[:2],
        )
        keep_mask_lowres = NcutVisionPredictor._get_component_keep_mask_lowres(
            lowres_cluster_assignment,
            min_area_ratio=min_area_ratio,
        )
        keep_mask = NcutVisionPredictor._resize_bool_mask(
            keep_mask_lowres,
            drawing.shape[:2],
        )

        boundary = np.zeros(labels.shape, dtype=bool)

        horizontal_diff = labels[:, :-1] != labels[:, 1:]
        horizontal_diff &= keep_mask[:, :-1] & keep_mask[:, 1:]
        boundary[:, :-1] |= horizontal_diff

        vertical_diff = labels[:-1, :] != labels[1:, :]
        vertical_diff &= keep_mask[:-1, :] & keep_mask[1:, :]
        boundary[:-1, :] |= vertical_diff

        boundary[0, :] |= keep_mask[0, :]
        boundary[-1, :] |= keep_mask[-1, :]
        boundary[:, 0] |= keep_mask[:, 0]
        boundary[:, -1] |= keep_mask[:, -1]

        drawing[boundary] = 0
        return Image.fromarray(drawing)

    @staticmethod
    def _resize_label_map(labels: np.ndarray, image_hw: Tuple[int, int]) -> np.ndarray:
        """Nearest-neighbour resize of an integer label map to (h, w)."""
        target_h, target_w = image_hw
        if labels.shape == (target_h, target_w):
            return labels.astype(np.int32, copy=False)

        label_img = Image.fromarray(labels.astype(np.int32))
        label_img = label_img.resize((target_w, target_h), Image.Resampling.NEAREST)
        return np.array(label_img, dtype=np.int32)

    @staticmethod
    def _resize_bool_mask(mask: np.ndarray, image_hw: Tuple[int, int]) -> np.ndarray:
        """Nearest-neighbour resize of a boolean mask to (h, w)."""
        target_h, target_w = image_hw
        if mask.shape == (target_h, target_w):
            return mask.astype(bool, copy=False)

        mask_img = Image.fromarray(mask.astype(np.uint8) * 255)
        mask_img = mask_img.resize((target_w, target_h), Image.Resampling.NEAREST)
        return np.array(mask_img, dtype=np.uint8) > 0

    @staticmethod
    def _get_component_keep_mask_lowres(
        lowres_cluster_assignment: np.ndarray,
        min_area_ratio: float = 0.0005,
    ) -> np.ndarray:
        """Mark pixels of 8-connected same-label components whose size is at
        least ``labels.size * min_area_ratio``."""
        labels = np.asarray(lowres_cluster_assignment, dtype=np.int32)
        keep_mask = np.zeros(labels.shape, dtype=bool)
        visited = np.zeros(labels.shape, dtype=bool)
        area_threshold = labels.size * min_area_ratio
        height, width = labels.shape

        for row in range(height):
            for col in range(width):
                if visited[row, col]:
                    continue

                # iterative flood fill over the 8-neighbourhood
                label = labels[row, col]
                component = []
                stack = [(row, col)]
                visited[row, col] = True

                while stack:
                    y, x = stack.pop()
                    component.append((y, x))

                    row_start = max(0, y - 1)
                    row_end = min(height, y + 2)
                    col_start = max(0, x - 1)
                    col_end = min(width, x + 2)
                    for ny in range(row_start, row_end):
                        for nx in range(col_start, col_end):
                            if (ny == y and nx == x) or visited[ny, nx]:
                                continue
                            if labels[ny, nx] != label:
                                continue
                            visited[ny, nx] = True
                            stack.append((ny, nx))

                if len(component) < area_threshold:
                    continue

                for y, x in component:
                    keep_mask[y, x] = True

        return keep_mask

    @staticmethod
    def _image_xy_to_tensor_index(image_whs: np.array,
                                  feat_hws: np.array,
                                  point_coords: np.ndarray,
                                  image_indices: np.ndarray) -> np.ndarray:
        """
        Convert image xy coordinates to row indices of the flattened features.

        Features are flattened row-major per image (``permute(0, 2, 3, 1)``),
        so the index is ``img * h * w + row * w + col``.

        Args:
            image_whs: Image width and height per image, (n_images, 2)
            feat_hws: Feature map height and width, (2, )
            point_coords: Point (x, y) coordinates in image pixels, (n_points, 2)
            image_indices: Image index for each point, (n_points, )
        Returns:
            Point indices, int64 (n_points, )
        """
        point_coords = np.asarray(point_coords, dtype=np.float64).reshape(-1, 2)
        image_indices = np.asarray(image_indices, dtype=np.int64).reshape(-1)
        if len(point_coords) == 0:
            return np.array([], dtype=np.int64)

        feat_h, feat_w = int(feat_hws[0]), int(feat_hws[1])
        wh = np.asarray(image_whs)[image_indices]  # (n_points, 2) as (w, h)
        normalized_xy = point_coords / wh

        # clip so a click on the far right/bottom edge stays inside the map
        rows = np.clip((normalized_xy[:, 1] * feat_h).astype(np.int64), 0, feat_h - 1)
        cols = np.clip((normalized_xy[:, 0] * feat_w).astype(np.int64), 0, feat_w - 1)

        point_indices = image_indices * (feat_h * feat_w) + rows * feat_w + cols
        return point_indices.astype(np.int64)

    def __check_initialized(self):
        if not self._initialized:
            raise NotInitializedError("Not initialized, please call set_images() first")

    def to(self, device: Union[str, torch.device]):
        """Move the model and set the predictor's device."""
        self.model = self.model.to(device)
        self.predictor = self.predictor.to(device)
        return self

color_continues()

Color the features with the continuous mspace color palette.

Returns:

Type Description
ndarray

np.ndarray: RGB image. (b, h, w, 3)

Source code in ncut_pytorch/predictor/vision_predictor.py
def color_continues(self) -> np.ndarray:
    """
    Color the features with the continuous mspace color palette.

    Values are clamped to [0, 255] before the uint8 cast, so palette values
    outside [0, 1] saturate instead of overflowing.

    Returns:
        np.ndarray: RGB image. (b, h, w, 3)
    """
    self.__check_initialized()
    b, h, w = len(self._images), self._feat_hws[0], self._feat_hws[1]
    color_palette = self.predictor.get_color_palette()
    rgb = color_palette.reshape(b, h, w, 3)
    rgb = (rgb * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
    return rgb

color_discrete(cluster_assignment, draw_border=True)

Color the features with a discrete palette: each cluster gets the mean mspace color of its members.

Parameters:

Name Type Description Default
cluster_assignment Tensor

Cluster assignment for the images. (b, h, w)

required
draw_border bool

Whether to draw segment borders. Defaults to True.

True

Returns:

Type Description
List[Image]

List[Image.Image]: List of RGB images.

Source code in ncut_pytorch/predictor/vision_predictor.py
def color_discrete(self, 
                   cluster_assignment: torch.Tensor,
                   draw_border: bool = True,
                   ) -> List[Image.Image]:
    """
    Color each cluster with the mean mspace palette color of its members.

    Args:
        cluster_assignment (torch.Tensor): Cluster assignment for the images. (b, h, w)
        draw_border (bool, optional): Whether to draw segment borders. Defaults to True.

    Returns:
        List[Image.Image]: List of RGB images, resized to the original image sizes.
    """
    self.__check_initialized()
    b, h, w = cluster_assignment.shape
    lowres_cluster_assignment = cluster_assignment.cpu().numpy()
    color_palette = self.predictor.get_color_palette()
    n_cluster = int(cluster_assignment.max().item()) + 1
    flat_cluster_assignment = cluster_assignment.flatten()
    discrete_rgb = np.zeros((b * h * w, 3), dtype=np.uint8)
    for i in range(n_cluster):
        mask = flat_cluster_assignment == i
        if not mask.any():
            # label ids may be sparse; skip ids with no members
            continue
        color = color_palette[mask].mean(0)
        color = (color * 255).cpu().numpy()
        color = np.clip(color, 0, 255).astype(np.uint8)
        discrete_rgb[mask.cpu().numpy()] = color
    discrete_rgb = discrete_rgb.reshape(b, h, w, 3)

    # convert to PIL image and resize to the original size
    pil_images = []
    for i in range(b):
        img = Image.fromarray(discrete_rgb[i])
        img = img.resize(self._images[i].size, Image.Resampling.NEAREST)
        if draw_border:
            img = self._draw_segments_border(img, lowres_cluster_assignment[i])
        pil_images.append(img)

    return pil_images

generate(n_segment, n_eig=10)

generate the cluster assignment for the images.

Parameters:

Name Type Description Default
n_segment int

Number of segments (clusters) to generate.

required

Returns:

Type Description
Tensor

torch.Tensor: Cluster assignment for the images. (b, h, w)

Source code in ncut_pytorch/predictor/vision_predictor.py
def generate(self, n_segment: int, n_eig: int = 10) -> torch.Tensor:
    """
    Generate the cluster assignment for the images.

    Args:
        n_segment (int): Number of segments (clusters) to generate.
        n_eig (int, optional): Currently unused; kept for API compatibility.

    Returns:
        torch.Tensor: Cluster assignment for the images. (b, h, w)
    """
    self.__check_initialized()
    cluster_assignment = self.predictor.get_n_segments(n_segment)
    b, h, w = len(self._images), self._feat_hws[0], self._feat_hws[1]
    # flat features are ordered image-major, then row-major within each image
    cluster_assignment = cluster_assignment.reshape(b, h, w)
    return cluster_assignment

inference(images)

Infer the mask and heatmap for new images, using the state saved by the predict function.

Parameters:

Name Type Description Default
images List[Image]

List of images to run inference on.

required

Returns:

Type Description
Tuple[Tensor, Tensor]

Tuple[torch.Tensor, torch.Tensor]: Mask and heatmap for the images. (b, h, w)

Source code in ncut_pytorch/predictor/vision_predictor.py
def inference(self, images: List[Image.Image]) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Run the saved click prompt on unseen images.

    Features are extracted and flattened to (b * h * w, c). They are then
    propagated through the state stored by the last predict call.

    Args:
        images (List[Image.Image]): Images to run inference on.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Mask and heatmap for the images. (b, h, w)
    """
    self.__check_initialized()
    feats = self.forward_model(images)
    batch, channels, feat_h, feat_w = feats.shape
    flat_feats = feats.permute(0, 2, 3, 1).reshape(-1, channels)

    mask, heatmap = self.predictor.inference_new_features(flat_feats)

    out_shape = (batch, feat_h, feat_w)
    return mask.reshape(out_shape), heatmap.reshape(out_shape)

predict(point_coords, point_labels, image_indices, click_weight=0.5, **kwargs)

predict the mask and heatmap for the images, based on the clicks.

Parameters:

Name Type Description Default
point_coords ndarray

The coordinates of the points to predict, in original image resolution. (n, 2)

required
point_labels ndarray

The labels of the points to predict, can be 1 (positive) or 0 (negative). (n, )

required
image_indices ndarray

The indices of the images corresponding to point_coords. (n, )

required
click_weight float

The weight of the click. Defaults to 0.5.

0.5

Returns:

Type Description
Tuple[Tensor, Tensor]

Tuple[torch.Tensor, torch.Tensor]: Mask and heatmap for the images. (b, h, w)

Source code in ncut_pytorch/predictor/vision_predictor.py
def predict(self,
            point_coords: np.ndarray,
            point_labels: np.ndarray,
            image_indices: np.ndarray,
            click_weight: float = 0.5,
            **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Predict the mask and heatmap for the images, based on the clicks.

    The three point arrays are converted with ``np.asarray``, so plain Python
    lists are accepted. Without the conversion, ``list == 1`` evaluates to a
    single bool, and indexing a list with it silently selects the wrong point.

    Args:
        point_coords (np.ndarray): The coordinates of the points to predict, in original image resolution. (n, 2)
        point_labels (np.ndarray): The labels of the points to predict, can be 1 (positive) or 0 (negative). (n, )
            Points with any other label are ignored.
        image_indices (np.ndarray): The indices of the images corresponding to the point_coords. (n, )
        click_weight (float, optional): The weight of the click. Defaults to 0.5.
        **kwargs: Forwarded to ``predictor.predict_clicks``.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Mask and heatmap for the images. (b, h, w)
    """
    self.__check_initialized()
    point_coords = np.asarray(point_coords)
    point_labels = np.asarray(point_labels)
    image_indices = np.asarray(image_indices)

    is_fg = point_labels == 1
    is_bg = point_labels == 0
    fg_indices = self._image_xy_to_tensor_index(self._image_whs, self._feat_hws,
                                                point_coords[is_fg],
                                                image_indices[is_fg])
    bg_indices = self._image_xy_to_tensor_index(self._image_whs, self._feat_hws,
                                                point_coords[is_bg],
                                                image_indices[is_bg])
    b, h, w = len(self._images), self._feat_hws[0], self._feat_hws[1]

    mask, heatmap = self.predictor.predict_clicks(fg_indices, bg_indices, click_weight, **kwargs)

    mask = mask.reshape(b, h, w)
    heatmap = heatmap.reshape(b, h, w)

    return mask, heatmap

preview(point_coord, image_indices)

preview the hierarchy cluster assignment for the images.

Parameters:

Name Type Description Default
point_coord Tuple[int, int]

The coordinate of the point to preview, in original image resolution.

required
image_indices int

The index of the image to preview, corresponds to the point_coord.

required

Returns:

Type Description
List[Tensor]

List[torch.Tensor]: List of masks for each hierarchy level. each mask is (b, h, w)

Source code in ncut_pytorch/predictor/vision_predictor.py
def preview(self,
            point_coord: Tuple[int, int],
            image_indices: int) -> List[torch.Tensor]:
    """
    Preview the hierarchy cluster assignment for a single clicked point.

    The point is mapped to its flattened feature index with
    ``_image_xy_to_tensor_index``. ``predictor.get_hierarchy_masks`` then
    returns one flat mask per hierarchy level, and each is reshaped to
    the (b, h, w) feature grid.

    Args:
        point_coord (Tuple[int, int]): The coordinate of the point to preview, in original image resolution.
        image_indices (int): The index of the image to preview, corresponds to the point_coord.

    Returns:
        List[torch.Tensor]: List of masks for each hierarchy level. each mask is (b, h, w)
    """
    self.__check_initialized()
    batch = len(self._images)
    height, width = self._feat_hws[0], self._feat_hws[1]
    grid = (batch, height, width)

    # The index helper works on batches of points; wrap the single click.
    coords = np.array([point_coord])
    owners = np.array([image_indices])
    flat_index = self._image_xy_to_tensor_index(self._image_whs, self._feat_hws, coords, owners)[0]

    level_masks = self.predictor.get_hierarchy_masks(flat_index)
    return [level.reshape(grid) for level in level_masks]

set_images(images, n_segments=(5, 25, 50, 100, 250))

Set the images and cache their features.

Parameters:

Name Type Description Default
images List[Image]

List of images to set.

required
n_segments List[int]

Number of segments to cache. n_segments is shown in the preview function.

(5, 25, 50, 100, 250)
Source code in ncut_pytorch/predictor/vision_predictor.py
def set_images(self,
               images: List[Image.Image],
               n_segments: List[int] = (5, 25, 50, 100, 250)):
    """
    Set the images, extract their features and initialize the predictor on them.

    Args:
        images (List[Image.Image]): List of images to set.
        n_segments (List[int], optional): Number of segments to cache.
            n_segments is shown in the preview function.

    Note:
        ``_initialized`` is cleared before any new state is built. It is only
        set back to True after ``predictor.initialize`` succeeds. If feature
        extraction or initialization raises, a flag left over from an earlier
        call is not paired with half-updated image metadata.
    """
    # Invalidate first. NOTE(review): assumes __check_initialized tests
    # `_initialized`; that helper is not visible here.
    self._initialized = False

    features = self.forward_model(images)  # (b, c, h, w)
    self._images = images
    self._image_whs = np.array([image.size for image in images])
    self._feat_hws = (features.shape[2], features.shape[3])

    # (b, c, h, w) -> (b*h*w, c): one feature row per pixel
    flat_features = features.permute(0, 2, 3, 1).reshape(-1, features.shape[1])
    self.predictor.initialize(flat_features, n_segments)
    self._initialized = True

summary(n_segments=(5, 25, 50, 100, 250), n_eig=10, draw_border=True)

Summarize the cluster assignment for the images.

Parameters:

Name Type Description Default
n_segments List[int]

Number of segments to summarize.

(5, 25, 50, 100, 250)
Source code in ncut_pytorch/predictor/vision_predictor.py
def summary(self,
            n_segments: List[int] = (5, 25, 50, 100, 250),
            n_eig: int = 10,
            draw_border: bool = True,
            display_hw: int = 512,
            ) -> "Image.Image":
    """
    Summarize the cluster assignment for the images as a single grid image.

    There is one row per image previously passed to ``set_images``. The
    columns are, in order:
    - the original image;
    - one discrete coloring for each entry of ``n_segments``;
    - a final continuous coloring from ``color_continues``.

    Args:
        n_segments (List[int]): Number of segments for each discrete column.
        n_eig (int): Number of eigenvectors passed to ``generate``.
        draw_border (bool): Passed to ``color_discrete``.
        display_hw (int): Side length in pixels of each (square) grid cell.
            Every panel is resized to ``display_hw x display_hw`` with
            nearest-neighbor resampling. Defaults to 512.

    Returns:
        Image.Image: RGB grid image, ``(len(n_segments) + 2) * display_hw``
            pixels wide and ``len(images) * display_hw`` pixels tall.
    """
    self.__check_initialized()

    # panels[j][i] is the j-th panel of the i-th image
    panels = [self._images]
    for n_segment in n_segments:
        cluster_assignment = self.generate(n_segment, n_eig=n_eig)
        panels.append(self.color_discrete(cluster_assignment, draw_border=draw_border))
    panels.append(self.color_continues())

    # make a grid of images
    n_rows = len(self._images)
    n_cols = len(n_segments) + 2
    grid_image = Image.new('RGB', size=(n_cols * display_hw, n_rows * display_hw))
    for i in range(n_rows):
        for j in range(n_cols):
            img = Image.fromarray(np.array(panels[j][i]))
            img = img.resize((display_hw, display_hw), Image.Resampling.NEAREST)
            grid_image.paste(img, box=(j * display_hw, i * display_hw))
    return grid_image

ncut_pytorch.predictor.dino_predictor.NcutDinov3Predictor

Bases: NcutVisionPredictor

Source code in ncut_pytorch/predictor/dino_predictor.py
class NcutDinov3Predictor(NcutVisionPredictor):
    """Vision predictor built on a ``Dinov3Backbone`` feature extractor."""

    def __init__(self,
                 input_size: Tuple[int, int] = (2048, 2048),
                 model_cfg: str = "dinov3_vitl16",
                 batch_size: int = 8,
    ):
        """
        Args:
            input_size (Tuple[int, int]): Passed as ``resize`` to ``get_input_transform``.
            model_cfg (str): Configuration name handed to ``Dinov3Backbone``.
            batch_size (int): Forwarded to ``NcutVisionPredictor.__init__``.
        """
        super().__init__(
            Dinov3Backbone(model_cfg),
            get_input_transform(resize=input_size),
            batch_size,
        )

ncut_pytorch.predictor.dino_predictor.NcutDinoPredictor

Bases: NcutVisionPredictor

Source code in ncut_pytorch/predictor/dino_predictor.py
class NcutDinoPredictor(NcutVisionPredictor):
    """Vision predictor built on the ``LowResDINO`` feature extractor."""

    def __init__(self,
                 input_size: Tuple[int, int] = (512, 512),
                 dtype: torch.dtype = torch.float32,
                 batch_size: int = 8):
        """
        Args:
            input_size (Tuple[int, int]): Passed as ``resize`` to ``get_input_transform``.
            dtype (torch.dtype): dtype handed to ``LowResDINO``.
            batch_size (int): Forwarded to ``NcutVisionPredictor.__init__``.
        """
        super().__init__(
            LowResDINO(dtype=dtype),
            get_input_transform(resize=input_size),
            batch_size,
        )

ncut_pytorch.predictor.dino_predictor.NcutDinoPredictorFeatUp

Bases: NcutVisionPredictor

Source code in ncut_pytorch/predictor/dino_predictor.py
class NcutDinoPredictorFeatUp(NcutVisionPredictor):
    """Vision predictor using the FeatUp 'dino' model loaded through ``torch.hub``."""

    def __init__(self,
                 input_size: Tuple[int, int] = (512, 512),
                 batch_size: int = 8):
        """
        Args:
            input_size (Tuple[int, int]): Passed as ``resize`` to ``get_input_transform``.
            batch_size (int): Forwarded to ``NcutVisionPredictor.__init__``.
        """
        # torch.hub.load may download the model on first use.
        super().__init__(
            torch.hub.load("huzeyann/FeatUp", 'dino', use_norm=False),
            get_input_transform(resize=input_size),
            batch_size,
        )

ncut_pytorch.predictor.dino_predictor.NcutDinoPredictorSR

Bases: NcutVisionPredictor

Source code in ncut_pytorch/predictor/dino_predictor.py
class NcutDinoPredictorSR(NcutVisionPredictor):
    """Vision predictor whose model and transform come from ``SUPER_RESOLUTION_MODELS``."""

    def __init__(self,
                 input_size: int = 512,
                 dtype: torch.dtype = torch.float16,
                 batch_size: int = 1):
        """
        Args:
            input_size (int): Key into ``SUPER_RESOLUTION_MODELS`` selecting the model factory.
            dtype (torch.dtype): dtype passed to the selected factory.
            batch_size (int): Forwarded to ``NcutVisionPredictor.__init__``.

        Raises:
            KeyError: If ``input_size`` is not a key of ``SUPER_RESOLUTION_MODELS``;
                the message lists the supported sizes.
        """
        try:
            model_factory = SUPER_RESOLUTION_MODELS[input_size]
        except KeyError:
            supported = ", ".join(str(size) for size in SUPER_RESOLUTION_MODELS)
            raise KeyError(
                f"unsupported input_size={input_size!r}; supported sizes: {supported}"
            ) from None
        model, transform = model_factory(dtype=dtype)
        super().__init__(model, transform, batch_size)

Color Utilities

ncut_pytorch.color.coloring.mspace_color(X, q=0.95, n_eig_list=[4, 16, 64], n_dim=3, encoder_training_steps=3000, decoder_training_steps=0, progress_bar=False, **kwargs)

Returns:

Type Description
Tensor

RGB color for each data sample, shape (n_samples, 3). Only the colors are returned; the intermediate low-dimensional embedding is not.

Source code in ncut_pytorch/color/coloring.py
def mspace_color(
        X: torch.Tensor,
        q: float = 0.95,
        n_eig_list: Optional[List[int]] = [4, 16, 64],
        n_dim: int = 3,
        encoder_training_steps: int = 3000,
        decoder_training_steps: int = 0,
        progress_bar: bool = False,
        **kwargs: Any,
):
    """
    Color data samples by embedding them with ``mspace_viz_transform`` and
    mapping the low-dimensional embedding to RGB with ``rgb_from_nd_colormap``.

    Args:
        X (torch.Tensor): Input features, presumably (n_samples, n_features) — TODO confirm.
        q (float): Quantile passed to ``rgb_from_nd_colormap``.
        n_eig_list (List[int] | None): Forwarded to ``mspace_viz_transform``.
            A copy is passed, so the shared default list can never be mutated.
        n_dim (int): Embedding dimension (``z_dim``).
        encoder_training_steps (int): Forwarded to ``mspace_viz_transform``.
        decoder_training_steps (int): Forwarded to ``mspace_viz_transform``.
        progress_bar (bool): Forwarded to ``mspace_viz_transform``.
        **kwargs: Forwarded to ``mspace_viz_transform``. They may override the
            default loss settings (``flag_loss_mode``, ``flag_loss``,
            ``recon_loss``, ``zero_center_loss``, ``repulsion_loss``).
            Two legacy aliases are consumed here instead:
            ``training_steps`` replaces ``encoder_training_steps``, and
            ``n_eig`` replaces ``n_eig_list`` with ``[n_eig]``.

    Returns:
        (torch.Tensor): RGB color for each data sample, shape (n_samples, 3)
    """
    from .mspace import mspace_viz_transform

    training_steps = kwargs.pop("training_steps", None)
    if training_steps is not None:
        encoder_training_steps = training_steps

    # `n_eig_list` is a named parameter, so it can never appear in **kwargs.
    # When given, the legacy `n_eig` therefore always takes precedence.
    n_eig = kwargs.pop("n_eig", None)
    if n_eig is not None:
        n_eig_list = [n_eig]
    elif n_eig_list is not None:
        n_eig_list = list(n_eig_list)  # never hand out the mutable default

    # Defaults first; caller kwargs override them instead of raising
    # "got multiple values for keyword argument".
    transform_options = {
        "flag_loss_mode": 'z',
        "flag_loss": 1.0,
        "recon_loss": 0.001,
        "zero_center_loss": 0.001,
        "repulsion_loss": 0.001,
    }
    transform_options.update(kwargs)

    low_dim_embedding = mspace_viz_transform(
        X=X,
        n_eig_list=n_eig_list,
        z_dim=n_dim,
        encoder_training_steps=encoder_training_steps,
        decoder_training_steps=decoder_training_steps,
        progress_bar=progress_bar,
        **transform_options)

    return rgb_from_nd_colormap(low_dim_embedding, q=q)

ncut_pytorch.color.coloring.tsne_color(X, num_sample=1000, perplexity=150, n_dim=3, metric='cosine', device=None, seed=None, q=0.95, knn=10, **kwargs)

Returns:

Type Description
Tensor

RGB color for each data sample, shape (n_samples, 3). Only the colors are returned; the intermediate low-dimensional embedding is not.

Source code in ncut_pytorch/color/coloring.py
def tsne_color(
        X: torch.Tensor,
        num_sample: int = 1000,
        perplexity: int = 150,
        n_dim: int = 3,
        metric: Literal["cosine", "euclidean"] = "cosine",
        device: str = None,
        seed: int = None,
        q: float = 0.95,
        knn: int = 10,
        **kwargs: Any,
):
    """
    Color data samples with a t-SNE embedding.

    Delegates to ``_nystrom_dimension_reduction`` with ``reduction=TSNE`` and
    ``rgb_func=rgb_from_nd_colormap``.

    Args:
        X (torch.Tensor): Input features; ``X.shape[0]`` is the sample count.
        num_sample (int): Number of samples, clipped to ``X.shape[0]``.
        perplexity (int): t-SNE perplexity, reduced to ``num_sample // 2`` (with a
            warning) if it is larger.
        n_dim (int): Embedding dimension.
        metric (str): "cosine" or "euclidean".
        device (str): Forwarded to ``_nystrom_dimension_reduction``.
        seed (int): Forwarded to ``_nystrom_dimension_reduction``.
        q (float): Forwarded to ``_nystrom_dimension_reduction``.
        knn (int): Forwarded to ``_nystrom_dimension_reduction``.
        **kwargs: Accepted for API compatibility; currently ignored.

    Returns:
        (torch.Tensor): RGB color for each data sample, shape (n_samples, 3)

    Raises:
        ImportError: If scikit-learn is not installed.
    """
    try:
        from sklearn.manifold import TSNE
    except ImportError as err:
        raise ImportError(
            "sklearn import failed, please install `pip install scikit-learn`"
        ) from err
    num_sample = min(num_sample, X.shape[0])
    max_perplexity = num_sample // 2
    if perplexity > max_perplexity:
        warnings.warn(
            f"perplexity ({perplexity}) is larger than num_sample // 2, "
            f"set perplexity to {max_perplexity}",
            stacklevel=2,
            category=UserWarning,
        )
        perplexity = max_perplexity

    _, rgb = _nystrom_dimension_reduction(
        X=X,
        num_sample=num_sample,
        metric=metric,
        rgb_func=rgb_from_nd_colormap,
        q=q, knn=knn,
        seed=seed, device=device,
        reduction=TSNE, reduction_dim=n_dim, reduction_kwargs={
            "perplexity": perplexity,
        },
    )

    return rgb

ncut_pytorch.color.coloring.umap_color(X, num_sample=1000, n_neighbors=150, min_dist=1.0, spread=10.0, n_dim=3, metric='cosine', device=None, seed=None, q=0.95, knn=10, **kwargs)

Returns:

Type Description
Tensor

RGB color for each data sample, shape (n_samples, 3). Only the colors are returned; the intermediate low-dimensional embedding is not.

Source code in ncut_pytorch/color/coloring.py
def umap_color(
        X: torch.Tensor,
        num_sample: int = 1000,
        n_neighbors: int = 150,
        min_dist: float = 1.0,
        spread: float = 10.0,
        n_dim: int = 3,
        metric: Literal["cosine", "euclidean"] = "cosine",
        device: str = None,
        seed: int = None,
        q: float = 0.95,
        knn: int = 10,
        **kwargs: Any,
):
    """
    Color data samples with a UMAP embedding.

    Delegates to ``_nystrom_dimension_reduction`` with ``reduction=UMAP`` and
    ``rgb_func=rgb_from_nd_colormap``.

    Args:
        X (torch.Tensor): Input features; ``X.shape[0]`` is the sample count.
        num_sample (int): Number of samples, clipped to ``X.shape[0]``
            (consistent with ``tsne_color``).
        n_neighbors (int): UMAP ``n_neighbors``.
        min_dist (float): UMAP ``min_dist``.
        spread (float): UMAP ``spread``.
        n_dim (int): Embedding dimension.
        metric (str): "cosine" or "euclidean".
        device (str): Forwarded to ``_nystrom_dimension_reduction``.
        seed (int): Forwarded to ``_nystrom_dimension_reduction``.
        q (float): Forwarded to ``_nystrom_dimension_reduction``.
        knn (int): Forwarded to ``_nystrom_dimension_reduction``.
        **kwargs: Accepted for API compatibility; currently ignored.

    Returns:
        (torch.Tensor): RGB color for each data sample, shape (n_samples, 3)

    Raises:
        ImportError: If umap-learn is not installed.
    """
    try:
        from umap import UMAP
    except ImportError as err:
        raise ImportError("umap import failed, please install `pip install umap-learn`") from err

    num_sample = min(num_sample, X.shape[0])

    _, rgb = _nystrom_dimension_reduction(
        X=X,
        num_sample=num_sample,
        metric=metric,
        rgb_func=rgb_from_nd_colormap,
        q=q, knn=knn,
        seed=seed, device=device,
        reduction=UMAP, reduction_dim=n_dim, reduction_kwargs={
            "n_neighbors": n_neighbors,
            "min_dist": min_dist,
            "spread": spread,
            "low_memory": False,
            "n_epochs": 200,
        },
    )

    return rgb

ncut_pytorch.color.coloring.rotate_rgb_cube(rgb, position=1)

Rotate the RGB cube to a different position.

Parameters:

Name Type Description Default
rgb Tensor

RGB color space [0, 1], shape (*, 3)

required
position int

position to rotate, 0, 1, 2, 3, 4, 5, 6

1

Returns:

Type Description

torch.Tensor: Rotated RGB colors, same shape as the input (*, 3)

Source code in ncut_pytorch/color/coloring.py
def rotate_rgb_cube(rgb, position=1):
    """Rotate the RGB cube to a different position.

    The color channels are cyclically permuted ``position % 3`` times:
    one step maps ``[r, g, b]`` to ``[b, r, g]``. For ``position > 3`` the
    colors are also inverted (``1 - rgb``). With this scheme, position 3
    gives the same result as position 0.

    Args:
        rgb (torch.Tensor): RGB color space [0, 1], shape (*, 3)
        position (int): position to rotate, 0, 1, 2, 3, 4, 5, 6

    Returns:
        torch.Tensor: Rotated RGB colors, same shape as ``rgb``. A floating
            point tensor input keeps its dtype and device.
    """
    assert position in range(0, 7), "position should be 0, 1, 2, 3, 4, 5, 6"
    rotation_matrix = torch.tensor(
        [
            [0, 1, 0],
            [0, 0, 1],
            [1, 0, 0],
        ]
    ).float()
    n_mul = position % 3
    # The matrix entries are 0/1, so the power is exact in float32.
    rotation_matrix = torch.matrix_power(rotation_matrix, n_mul)
    # Match the input's dtype/device. A float32 CPU matrix would otherwise make
    # `@` fail for float64, half or CUDA inputs.
    if isinstance(rgb, torch.Tensor) and rgb.is_floating_point():
        rotation_matrix = rotation_matrix.to(device=rgb.device, dtype=rgb.dtype)
    rgb = rgb @ rotation_matrix
    if position > 3:
        rgb = 1 - rgb
    return rgb

Math & Utilities

ncut_pytorch.utils.math.rbf_affinity(X1, X2=None, sigma=1.0, zero_diag=False, gamma=None)

Computes RBF affinity matrix: W_ij = exp(-||x_i - x_j||^2 / (2 * sigma^2)).

Source code in ncut_pytorch/utils/math.py
def rbf_affinity(
    X1: torch.Tensor,          # [N,D]
    X2: torch.Tensor | None = None,  # [M,D]
    sigma: float = 1.0,
    zero_diag: bool = False,
    gamma: float | None = None,  # deprecated
) -> torch.Tensor:             # [N,M]
    """Computes RBF affinity matrix: W_ij = exp(-||x_i - x_j||^2 / (2 * sigma^2))."""
    sigma = sigma if gamma is None else check_gamma_deprecated(gamma)
    X2 = X1 if X2 is None else X2

    try:
        x1_sq = X1.pow(2).sum(dim=1, keepdim=True)
        if X2 is X1:
            dist2 = x1_sq + x1_sq.T
        else:
            x2_sq = X2.pow(2).sum(dim=1).unsqueeze(0)
            dist2 = x1_sq + x2_sq
        dist2.addmm_(X1, X2.T, beta=1.0, alpha=-2.0)
        dist2.clamp_min_(0)
    except RuntimeError:
        try:
            dist2 = torch.cdist(X1, X2, p=2).pow_(2)
        except NotImplementedError:
            dist2 = X1.unsqueeze(1) - X2.unsqueeze(0)
            dist2 = dist2.pow(2).sum(dim=-1)
    W = dist2.mul_(-0.5 / (sigma * sigma)).exp_()   # [N,M]
    if zero_diag and X1 is X2:
        W.fill_diagonal_(0.0)
    return W

ncut_pytorch.utils.math.cosine_affinity(X1, X2=None, sigma=1.0, repulse=False, zero_diag=False, gamma=None)

Computes cosine-based affinity matrix.

Source code in ncut_pytorch/utils/math.py
def cosine_affinity(
    X1: torch.Tensor,          # [N,D]
    X2: torch.Tensor | None = None,  # [M,D]
    sigma: float = 1.0,
    repulse: bool = False,
    zero_diag: bool = False,
    gamma: float | None = None,  # deprecated
) -> torch.Tensor:             # [N,M]
    """Computes cosine-based affinity matrix."""
    sigma = sigma if gamma is None else check_gamma_deprecated(gamma)
    X2 = X1 if X2 is None else X2

    X1_norm = torch.nn.functional.normalize(X1, p=2, dim=1, eps=1e-8)
    X2_norm = torch.nn.functional.normalize(X2, p=2, dim=1, eps=1e-8)
    S = torch.mm(X1_norm, X2_norm.T)
    num = S + 1 if repulse else S - 1
    W = torch.exp(- num**2 / (2.0 * sigma * sigma))
    if not repulse:
        W = W + 1e-3
    if zero_diag and X1 is X2:
        W = W.clone()
        W.fill_diagonal_(0.0)
    return W

ncut_pytorch.utils.math.quantile_normalize(x, q=0.95)

Normalizes each dimension of x to [0, 1] using quantiles, robust to outliers.

Source code in ncut_pytorch/utils/math.py
def quantile_normalize(
    x: torch.Tensor | np.ndarray,  # [n_samples, n_features]
    q: float = 0.95,
) -> torch.Tensor:             # [n_samples, n_features]
    """Normalizes each dimension of x to [0, 1] using quantiles, robust to outliers.

    Values are shifted by the lower bound and divided by (upper - lower). Both
    bounds come from ``quantile_min_max(x, q, 1 - q)``. The result is clipped
    to [0, 1]. Where the range is zero (e.g. a constant feature), entries equal
    to the lower bound map to 0 instead of the NaN that 0/0 would produce.
    Entries that were NaN on input stay NaN.

    Args:
        x: Data to normalize; a numpy array is converted with ``torch.tensor``.
        q: Upper quantile; the lower quantile is ``1 - q``.

    Returns:
        torch.Tensor: Normalized values in [0, 1].
    """
    if isinstance(x, np.ndarray):
        x = torch.tensor(x)

    vmax, vmin = quantile_min_max(x, q, 1 - q)
    shifted = x - vmin
    x = shifted / (vmax - vmin)
    # A zero numerator gives 0 for any non-zero range, so overriding those
    # entries only changes the 0/0 = NaN case (which clamp() would keep).
    x = torch.where(shifted == 0, torch.zeros_like(x), x)
    return x.clamp(0, 1)

ncut_pytorch.utils.sigma.find_sigma_by_degree(X, quantile_sigma=0.25, affinity_fn=rbf_affinity, X2=None, init_sigma=0.5, r_tol=0.01, max_iter=100, n_sample=1000)

Find sigma after FPS-based downsampling for efficiency.

Source code in ncut_pytorch/utils/sigma.py
@torch.no_grad()
def find_sigma_by_degree(
    X: torch.Tensor,                    # [n_samples, n_features]
    quantile_sigma: float = 0.25,
    affinity_fn: callable = rbf_affinity,
    X2: torch.Tensor | None = None,
    init_sigma: float = 0.5,
    r_tol: float = 1e-2,
    max_iter: int = 100,
    n_sample: int = 1000,
) -> float:
    """Find sigma after FPS-based downsampling for efficiency.

    The rows of ``X`` are subsampled with ``farthest_point_sampling(X, n_sample)``.
    The search itself is delegated to ``_find_sigma_by_degree`` with the
    remaining arguments. ``X2`` is forwarded unchanged and is not downsampled.
    The whole computation runs under ``torch.no_grad()``.
    """
    subset = X[farthest_point_sampling(X, n_sample)]
    return _find_sigma_by_degree(
        subset,
        quantile_sigma,
        affinity_fn,
        X2=X2,
        init_sigma=init_sigma,
        r_tol=r_tol,
        max_iter=max_iter,
    )