diff --git a/README.md b/README.md index 771f627..07c4b9b 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,7 @@ git clone https://github.com/subhadarship/kmeans_pytorch cd kmeans_pytorch pip install --editable . ``` +Installing from source requires 'numba' dependency. # CPU vs GPU see [`cpu_vs_gpu.ipynb`](https://github.com/subhadarship/kmeans_pytorch/blob/master/cpu_vs_gpu.ipynb) for a comparison between CPU and GPU diff --git a/kmeans_pytorch/__init__.py b/kmeans_pytorch/__init__.py index 1311f32..f6258bb 100644 --- a/kmeans_pytorch/__init__.py +++ b/kmeans_pytorch/__init__.py @@ -101,7 +101,8 @@ def kmeans( if selected.shape[0] == 0: selected = X[torch.randint(len(X), (1,))] - initial_state[index] = selected.mean(dim=0) + if torch.isnan(selected.mean(dim=0)).sum()==0: + initial_state[index] = selected.mean(dim=0) center_shift = torch.sum( torch.sqrt(