@@ -24,7 +24,7 @@
 # limitations under the License.
 """Inference-only Qwen3VL model compatible with HuggingFace weights."""

-from collections.abc import Callable, Iterable, Mapping, Sequence
+from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
 from functools import partial
 from itertools import islice
 from typing import Any
@@ -1412,72 +1412,47 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: |
         )
         return mm_input_by_modality

+    def iter_mm_grid_hw(
+        self, input_tokens: list[int], mm_features: list[MultiModalFeatureSpec]
+    ) -> Iterator[tuple[int, int, int]]:
+        video_token_id = self.config.video_token_id
+        spatial_merge_size = self.config.vision_config.spatial_merge_size
+        for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
+            offset = mm_feature.mm_position.offset
+            if mm_feature.modality == "image":
+                t, h, w = mm_feature.data["image_grid_thw"].data.tolist()
+                assert t == 1, f"Image must have 1 frame, got {t}"
+                yield offset, h // spatial_merge_size, w // spatial_merge_size
+            elif mm_feature.modality == "video":
+                t, h, w = mm_feature.data["video_grid_thw"].data.tolist()
+                llm_grid_h = h // spatial_merge_size
+                llm_grid_w = w // spatial_merge_size
+                for _ in range(t):
+                    offset = input_tokens.index(video_token_id, offset)
+                    yield offset, llm_grid_h, llm_grid_w
+                    offset += llm_grid_h * llm_grid_w
+            else:
+                raise ValueError(f"Unsupported modality: {mm_feature.modality}")
+
     def get_mrope_input_positions(
         self,
         input_tokens: list[int],
         mm_features: list[MultiModalFeatureSpec],
     ) -> tuple[torch.Tensor, int]:
-        kwargs = MultiModalFeatureSpec.gather_kwargs(
-            mm_features,
-            {"image_grid_thw", "video_grid_thw"},
-        )
-        image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]
-        video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])]
-
-        video_grid_thw = [[1, h, w] for t, h, w in video_grid_thw for _ in range(t)]
-
-        hf_config = self.config
-        image_token_id = hf_config.image_token_id
-        video_token_id = hf_config.video_token_id
-        vision_start_token_id = hf_config.vision_start_token_id
-        spatial_merge_size = hf_config.vision_config.spatial_merge_size
-
-        input_tokens_array = np.array(input_tokens)
-        vision_start_mask = input_tokens_array == vision_start_token_id
-        vision_tokens = input_tokens_array[vision_start_mask.nonzero()[0] + 1]
-        image_nums = np.count_nonzero(vision_tokens == image_token_id)
-        video_nums = np.count_nonzero(vision_tokens == video_token_id)
-        llm_pos_ids_list: list = []
-
+        llm_pos_ids_list = []
         st = 0
-        remain_images, remain_videos = image_nums, video_nums
-
-        image_index, video_index = 0, 0
-        for _ in range(image_nums + video_nums):
-            if image_token_id in input_tokens and remain_images > 0:
-                ed_image = input_tokens.index(image_token_id, st)
-            else:
-                ed_image = len(input_tokens) + 1
-            if video_token_id in input_tokens and remain_videos > 0:
-                ed_video = input_tokens.index(video_token_id, st)
-            else:
-                ed_video = len(input_tokens) + 1
-            if ed_image < ed_video:
-                t, h, w = image_grid_thw[image_index]
-                image_index += 1
-                remain_images -= 1
-                ed = ed_image
-            else:
-                t, h, w = video_grid_thw[video_index]
-                video_index += 1
-                remain_videos -= 1
-                ed = ed_video
-
-            llm_grid_t, llm_grid_h, llm_grid_w = (
-                t,
-                h // spatial_merge_size,
-                w // spatial_merge_size,
-            )
-            text_len = ed - st
-
+        for offset, llm_grid_h, llm_grid_w in self.iter_mm_grid_hw(
+            input_tokens, mm_features
+        ):
+            text_len = offset - st
             st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
             llm_pos_ids_list.append(
                 np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
             )

-            grid_indices = np.indices((llm_grid_t, llm_grid_h, llm_grid_w))
-            llm_pos_ids_list.append(grid_indices.reshape(3, -1) + text_len + st_idx)
-            st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+            grid_indices = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1)
+            llm_pos_ids_list.append(grid_indices + text_len + st_idx)
+            st = offset + llm_grid_h * llm_grid_w

         if st < len(input_tokens):
             st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
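The new `iter_mm_grid_hw` generator centralizes the placeholder bookkeeping that the old loop did by hand: it walks the multimodal features in offset order and yields one `(offset, grid_h, grid_w)` triple per image and per video frame. Note the `input_tokens.index(video_token_id, offset)` re-scan between frames: it skips over any non-placeholder tokens sitting between consecutive frames (e.g. Qwen3VL's interleaved timestamp text), so frames need not be contiguous. Below is a minimal standalone sketch of the same iteration pattern; the toy stand-ins (`VIDEO_TOKEN`, plain `(modality, offset, grid)` tuples) are illustrative substitutes for vLLM's `MultiModalFeatureSpec`, not its actual API:

```python
from collections.abc import Iterator

VIDEO_TOKEN = 9  # hypothetical placeholder id, for illustration only


def iter_mm_grid_hw_toy(
    input_tokens: list[int],
    features: list[tuple[str, int, tuple[int, int, int]]],
    spatial_merge_size: int = 2,
) -> Iterator[tuple[int, int, int]]:
    # features: (modality, offset, (t, h, w)) with grids in pre-merge patch units
    for modality, offset, (t, h, w) in sorted(features, key=lambda f: f[1]):
        grid_h, grid_w = h // spatial_merge_size, w // spatial_merge_size
        if modality == "image":
            assert t == 1  # images carry exactly one "frame"
            yield offset, grid_h, grid_w
        else:  # video: one triple per frame, re-scanning past interleaved text
            for _ in range(t):
                offset = input_tokens.index(VIDEO_TOKEN, offset)
                yield offset, grid_h, grid_w
                offset += grid_h * grid_w


# Two 2x2-patch frames (1x1 after merging) separated by a timestamp token (7):
tokens = [1, 2, VIDEO_TOKEN, 7, VIDEO_TOKEN, 3]
print(list(iter_mm_grid_hw_toy(tokens, [("video", 2, (2, 2, 2))])))
# -> [(2, 1, 1), (4, 1, 1)]
```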
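With placeholder discovery delegated to the iterator, `get_mrope_input_positions` becomes a single accumulation pass: each text span gets identical t/h/w positions, each vision span gets 3-D grid indices shifted past the preceding text, and trailing text is appended after the loop. A self-contained sketch of that accumulation (the `build_positions` helper and toy inputs are mine, not vLLM's):

```python
import numpy as np


def build_positions(num_tokens: int, grids: list[tuple[int, int, int]]) -> np.ndarray:
    # grids: (offset, grid_h, grid_w) triples sorted by offset,
    # as yielded by the iterator above
    pos_list: list[np.ndarray] = []
    st = 0
    for offset, grid_h, grid_w in grids:
        text_len = offset - st
        st_idx = pos_list[-1].max() + 1 if pos_list else 0
        # text tokens: t == h == w, shape (3, text_len)
        pos_list.append(np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx)
        # vision tokens: (t, h, w) indices over a 1 x grid_h x grid_w grid
        grid = np.indices((1, grid_h, grid_w)).reshape(3, -1)
        pos_list.append(grid + text_len + st_idx)
        st = offset + grid_h * grid_w
    if st < num_tokens:  # trailing text after the last vision span
        st_idx = pos_list[-1].max() + 1 if pos_list else 0
        tail = num_tokens - st
        pos_list.append(np.broadcast_to(np.arange(tail), (3, tail)) + st_idx)
    return np.concatenate(pos_list, axis=1)


# 7 tokens: 2 text tokens, a 2x2 image grid at offset 2, 1 trailing text token
print(build_positions(7, [(2, 2, 2)]))
# [[0 1 2 2 2 2 4]
#  [0 1 2 2 3 3 4]
#  [0 1 2 3 2 3 4]]
```

The four placeholder tokens share temporal position 2 while the h/w rows enumerate the 2x2 grid, which is the M-RoPE layout this code builds; the real method then turns the accumulated arrays into the `(3, seq_len)` position tensor it returns.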