Skip to content

Commit 745644f

Browse files
msaroufim authored and facebook-github-bot committed
FIX SAM for bfloat16 (#1764)
Summary: Ok this was kinda annoying Basically the SAM codebase had a few places where it hardcodes `torch.float32` such that even if you convert the model to `torch.bfloat16` a few parts of the model won't be and will have type mismatch errors - this fixes the problem cpuhrsch desertfire - idk enough about floats and why there isn't some type promotion rule for bfloat16 I wonder whether we should add tests for multiple dtypes in torchbench to make checking for this kind of issue more robust especially now that bfloat16 seems to be the default for dynamo xuzhao9 ## Logs ``` FAILED (errors=1) (sam) ubuntu@ip-172-31-9-217:~/benchmark$ python test.py -k "test_sam_eval_cuda" E ====================================================================== ERROR: test_sam_eval_cuda (__main__.TestBenchmark) ---------------------------------------------------------------------- components._impl.workers.subprocess_rpc.ChildTraceException: Traceback (most recent call last): File "/home/ubuntu/benchmark/components/_impl/workers/subprocess_rpc.py", line 482, in _run_block exec( # noqa: P204 File "<subprocess-worker>", line 2, in <module> File "/home/ubuntu/benchmark/torchbenchmark/util/model.py", line 280, in invoke out = self.eval() File "/home/ubuntu/benchmark/torchbenchmark/models/sam/__init__.py", line 65, in eval masks, scores, logits = predictor.predict( File "/home/ubuntu/benchmark/torchbenchmark/models/sam/predictor.py", line 164, in predict low_res_masks_np = low_res_masks[0].detach().cpu().numpy() TypeError: Got unsupported ScalarType BFloat16 working_dir: /tmp/tmpg5de41du stdout: [2023-07-13] 01:57:38.499061: TIMER_SUBPROCESS_BEGIN_EXEC [2023-07-13] 01:57:39.002078: TIMER_SUBPROCESS_FAILED [2023-07-13] 01:57:39.002141: TIMER_SUBPROCESS_FINISHED [2023-07-13] 01:57:39.002153: TIMER_SUBPROCESS_BEGIN_READ stderr: The above exception was the direct cause of the following exception: Traceback (most recent call last): File "/home/ubuntu/benchmark/test.py", line 104, in eval_fn 
task.invoke() File "/home/ubuntu/benchmark/torchbenchmark/__init__.py", line 402, in invoke self.worker.run(""" File "/home/ubuntu/benchmark/components/_impl/workers/subprocess_worker.py", line 155, in run self._run(snippet) File "/home/ubuntu/benchmark/components/_impl/workers/subprocess_worker.py", line 320, in _run subprocess_rpc.SerializedException.raise_from( File "/home/ubuntu/benchmark/components/_impl/workers/subprocess_rpc.py", line 458, in raise_from raise e from ChildTraceException(traceback_str) TypeError: Got unsupported ScalarType BFloat16 ---------------------------------------------------------------------- Ran 1 test in 7.814s FAILED (errors=1) (sam) ubuntu@ip-172-31-9-217:~/benchmark$ python test.py -k "test_sam_eval_cuda" . ---------------------------------------------------------------------- Ran 1 test in 8.315s OK ``` Pull Request resolved: #1764 Reviewed By: drisspg, cpuhrsch Differential Revision: D47441873 Pulled By: msaroufim fbshipit-source-id: a60880fd7c0826cfd469ace39d76894469ca0e5e
1 parent 2ea018e commit 745644f

File tree

4 files changed

+8
-3
lines changed

4 files changed

+8
-3
lines changed

torchbenchmark/models/sam/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,6 @@ def get_module(self):
4343
]
4444

4545
multimask_output = False
46-
4746
return self.model, (example_input, multimask_output)
4847

4948
def train(self):
@@ -57,6 +56,9 @@ def train(self):
5756
return NotImplementedError(error_msg)
5857

5958
def eval(self):
59+
# To test for bfloat16 uncomment the below line
60+
# predictor = SamPredictor(self.model.to(dtype=torch.bfloat16))
61+
6062
predictor = SamPredictor(self.model)
6163

6264
predictor.set_image(self.image)

torchbenchmark/models/sam/mask_decoder.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -129,6 +129,7 @@ def predict_masks(
129129
b, c, h, w = src.shape
130130

131131
# Run the transformer
132+
tokens = tokens.to(src.dtype)
132133
hs, src = self.transformer(src, pos_src, tokens)
133134
iou_token_out = hs[:, 0, :]
134135
mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]

torchbenchmark/models/sam/predictor.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -160,8 +160,8 @@ def predict(
160160
)
161161

162162
masks_np = masks[0].detach().cpu().numpy()
163-
iou_predictions_np = iou_predictions[0].detach().cpu().numpy()
164-
low_res_masks_np = low_res_masks[0].detach().cpu().numpy()
163+
iou_predictions_np = iou_predictions[0].to(torch.float32).detach().cpu().numpy()
164+
low_res_masks_np = low_res_masks[0].to(torch.float32).detach().cpu().numpy()
165165
return masks_np, iou_predictions_np, low_res_masks_np
166166

167167
@torch.no_grad()

torchbenchmark/models/sam/prompt_encoder.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -186,6 +186,8 @@ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
186186
"""Positionally encode points that are normalized to [0,1]."""
187187
# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
188188
coords = 2 * coords - 1
189+
coords = coords.to(self.positional_encoding_gaussian_matrix.dtype)
190+
189191
coords = coords @ self.positional_encoding_gaussian_matrix
190192
coords = 2 * np.pi * coords
191193
# outputs d_1 x ... x d_n x C shape

0 commit comments

Comments
 (0)