Hi there,
First, thank you for the great library! I've encountered a limitation when using batch_kmeans_Euclid with input tensors where the feature dimension (last dim) is less than 16, which triggers a Triton compilation error.
Code
import torch
from flash_kmeans import batch_kmeans_Euclid
# Create input with feature dimension = 1 (causes error)
x = torch.randn(32, 75600, 1, device="cuda", dtype=torch.float16)
cluster_ids, centers, _ = batch_kmeans_Euclid(x, n_clusters=1000, tol=1e-4, verbose=True)
Message
triton.compiler.errors.CompilationError: at 79:31: + pid_b * stride_c_b
+ k_offsets[None, :] * stride_c_k
+ offs_d[:, None] * stride_c_d
)
c_tile = tl.load(c_ptrs, mask=k_mask[None, :], other=0.0)
c_tile = c_tile
# Compute centroid squared norms (BLOCK_K,)
cent_sq = tl.sum(c_tile * c_tile, axis=0).to(tl.float32)
# Compute cross term (BLOCK_N, BLOCK_K) = x_tile @ c_tile
cross = tl.dot(x_tile, c_tile).to(tl.float32) # float32
^
AssertionError('All values in both first input shape ([constexpr[32], constexpr[1]]) and second input shape ([constexpr[1], constexpr[32]]) must be >= 16!')
Would it be possible to fix this issue to support feature dimensions smaller than 16?
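In the meantime, one stopgap I'm considering is to zero-pad the feature dimension up to 16 before calling the library. Zero columns contribute nothing to squared Euclidean distances, so the clustering should be mathematically unchanged. The sketch below shows only the padding step; `pad_feature_dim` and `MIN_DOT_DIM` are hypothetical names of my own, not part of flash_kmeans, and I have not confirmed that this avoids the assertion in every kernel.

```python
import torch
import torch.nn.functional as F

# Minimum dimension taken from the tl.dot assertion message (assumption:
# the same limit applies to every kernel the library launches).
MIN_DOT_DIM = 16


def pad_feature_dim(x: torch.Tensor, min_dim: int = MIN_DOT_DIM) -> torch.Tensor:
    """Zero-pad the last dimension of x up to min_dim; no-op if already large enough."""
    d = x.shape[-1]
    if d >= min_dim:
        return x
    # F.pad takes (left, right) padding amounts, starting from the last dimension.
    return F.pad(x, (0, min_dim - d))


def pairwise_sq_dist(t: torch.Tensor) -> torch.Tensor:
    """Squared Euclidean distances between all points within each batch."""
    diff = t[:, :, None, :] - t[:, None, :, :]
    return diff.pow(2).sum(dim=-1)


# Small CPU example: feature dimension 1 is padded to 16.
x = torch.randn(2, 100, 1)
x_padded = pad_feature_dim(x)

# The padded columns are all zero, so pairwise distances are preserved.
dist_before = pairwise_sq_dist(x)
dist_after = pairwise_sq_dist(x_padded)
```

The padded tensor would then be passed to batch_kmeans_Euclid in place of x. Assuming the returned centers keep the feature dimension last, the padded columns could be dropped afterwards with `centers[..., :x.shape[-1]]`. This costs extra memory and compute for very small feature dimensions, so native support would still be preferable.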