From 10bfc7c45e499c5a56858426174cca59e51c23b3 Mon Sep 17 00:00:00 2001 From: yangnianzu0515 Date: Tue, 9 May 2023 16:34:54 +0800 Subject: [PATCH] fix center_shift=nan and add a dependency for installing from source --- README.md | 1 + kmeans_pytorch/__init__.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 771f627..07c4b9b 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,7 @@ git clone https://github.com/subhadarship/kmeans_pytorch cd kmeans_pytorch pip install --editable . ``` +Installing from source requires 'numba' dependency. # CPU vs GPU see [`cpu_vs_gpu.ipynb`](https://github.com/subhadarship/kmeans_pytorch/blob/master/cpu_vs_gpu.ipynb) for a comparison between CPU and GPU diff --git a/kmeans_pytorch/__init__.py b/kmeans_pytorch/__init__.py index 1311f32..f6258bb 100644 --- a/kmeans_pytorch/__init__.py +++ b/kmeans_pytorch/__init__.py @@ -101,7 +101,8 @@ def kmeans( if selected.shape[0] == 0: selected = X[torch.randint(len(X), (1,))] - initial_state[index] = selected.mean(dim=0) + if torch.isnan(selected.mean(dim=0)).sum()==0: + initial_state[index] = selected.mean(dim=0) center_shift = torch.sum( torch.sqrt(