Triton Compilation Error: Matrix multiplication assertion fails for small dimensions (<16) in k-means kernel #1

@ljynlp

Hi there,

First, thank you for the great library! I've encountered a limitation when using batch_kmeans_Euclid with input tensors where the feature dimension (last dim) is less than 16, which triggers a Triton compilation error.

Code

import torch
from flash_kmeans import batch_kmeans_Euclid

# Create input with feature dimension = 1 (causes error)
x = torch.randn(32, 75600, 1, device="cuda", dtype=torch.float16)
cluster_ids, centers, _ = batch_kmeans_Euclid(x, n_clusters=1000, tol=1e-4, verbose=True)

Message

triton.compiler.errors.CompilationError: at 79:31:            + pid_b * stride_c_b
            + k_offsets[None, :] * stride_c_k
            + offs_d[:, None] * stride_c_d
        )
        c_tile = tl.load(c_ptrs, mask=k_mask[None, :], other=0.0)
        c_tile = c_tile

        # Compute centroid squared norms (BLOCK_K,)
        cent_sq = tl.sum(c_tile * c_tile, axis=0).to(tl.float32)

        # Compute cross term (BLOCK_N, BLOCK_K) = x_tile @ c_tile
        cross = tl.dot(x_tile, c_tile).to(tl.float32)  # float32
                               ^
AssertionError('All values in both first input shape ([constexpr[32], constexpr[1]]) and second input shape ([constexpr[1], constexpr[32]]) must be >= 16!')

Would it be possible to fix this issue to support feature dimensions smaller than 16?
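
In the meantime, one possible workaround is to zero-pad the feature dimension up to 16 before calling the kernel. Zero columns contribute nothing to squared Euclidean distances, so the distances between points are unchanged. The sketch below is only an illustration: `pad_feature_dim` and `MIN_DOT_DIM` are hypothetical names of my own, not part of flash_kmeans, and the value 16 is taken from the assertion message above.

```python
import torch
import torch.nn.functional as F

# Assumption: 16 is the minimum tile size, as stated in the tl.dot assertion above.
MIN_DOT_DIM = 16


def pad_feature_dim(x: torch.Tensor, min_dim: int = MIN_DOT_DIM) -> torch.Tensor:
    """Zero-pad the last dimension of x up to min_dim (hypothetical helper).

    Returns x unchanged if its last dimension is already >= min_dim.
    """
    d = x.shape[-1]
    if d >= min_dim:
        return x
    # (0, n) pads n zeros on the right side of the last dimension only.
    return F.pad(x, (0, min_dim - d))


def pairwise_sq_dist(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Reference squared Euclidean distances, shape (B, N, K)."""
    diff = a[:, :, None, :] - b[:, None, :, :]
    return (diff * diff).sum(dim=-1)


# Sanity check on CPU: padding does not change pairwise distances.
points = torch.randn(2, 50, 1)
centroids = torch.randn(2, 8, 1)
before = pairwise_sq_dist(points, centroids)
after = pairwise_sq_dist(pad_feature_dim(points), pad_feature_dim(centroids))
assert torch.allclose(before, after)
```

With this helper, the call would become `batch_kmeans_Euclid(pad_feature_dim(x), ...)`. I am assuming the returned centers keep the feature dimension last, in which case they could be sliced back with `centers[..., :x.shape[-1]]`; I have not verified this against the library, and the padding does add memory and compute for very small feature dimensions.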
