Commit 0233d1f

kcz/support-for-video-in-benchmark
1 parent b7a5a80 commit 0233d1f

File tree

5 files changed: +158 additions, -75 deletions


tools/llm_bench/benchmark.py

Lines changed: 4 additions & 1 deletion
@@ -229,6 +229,8 @@ def get_argprser():
                         help="Path to .bin or .pt file with speaker embeddings for text to speech scenarios")
     parser.add_argument("--vocoder_path", type=str, default=None,
                         help="Path to vocoder for text to speech scenarios")
+    parser.add_argument("-vf", "--video_frames", type=int, default=None,
+                        help="number of video frames to process (a negative value acts as a decimation factor)")
     return parser.parse_args()


@@ -315,7 +317,8 @@ def main():
             args.num_iters, memory_data_collector)
     else:
         iter_data_list, pretrain_time, iter_timestamp = CASE_TO_BENCH[model_args['use_case'].task](
-            model_path, framework, args.device, model_args, args.num_iters, memory_data_collector)
+            model_path, framework, args.device, model_args, args.num_iters,
+            memory_data_collector, args.video_frames)
     if args.report is not None or args.report_json is not None:
         model_precision = ''
         if framework == 'ov':
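
With the flag wired through, a video benchmark run might look like the line below. The model directory and prompt file are illustrative placeholders; the other flags follow the existing llm_bench CLI, and -vf 16 asks the video loader to reduce each clip to 16 frames:

    python benchmark.py -m ./models/qwen2-vl-ov -pf prompts/video_prompts.jsonl -d GPU -n 2 -vf 16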

tools/llm_bench/llm_bench_utils/parse_json_data.py

Lines changed: 37 additions & 48 deletions
@@ -2,78 +2,67 @@
 # Copyright (C) 2023-2025 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0

+def create_base_prompt(json_data, key='prompt'):
+    prompt_data = {}
+    if key in json_data:
+        if json_data[key] != "":
+            prompt_data[key] = json_data[key]
+        else:
+            raise RuntimeError(f"== {key} should not be empty string ==")
+    else:
+        raise RuntimeError(f"== key word '{key}' does not exist ==")
+    return prompt_data
+

 def parse_text_json_data(json_data_list):
     text_param_list = []
     for json_data in json_data_list:
-        if 'prompt' in json_data:
-            if json_data['prompt'] != '':
-                text_param_list.append(json_data['prompt'])
-            else:
-                raise RuntimeError('== prompt should not be empty string ==')
-        else:
-            raise RuntimeError('== key word "prompt" does not exist ==')
+        prompt_data = create_base_prompt(json_data)
+        text_param_list.append(prompt_data["prompt"])
     return text_param_list


 def parse_vlm_json_data(json_data_list):
     text_param_list = []
     for json_data in json_data_list:
-        prompt_data = {}
-        if 'prompt' in json_data:
-            if json_data['prompt'] != '':
-                prompt_data["prompt"] = json_data['prompt']
-            else:
-                raise RuntimeError('== prompt should not be empty string ==')
-        else:
-            raise RuntimeError('== key word "prompt" does not exist ==')
+        prompt_data = create_base_prompt(json_data)
+        if ("media" in json_data) and ("video" in json_data):
+            raise ValueError("only one of the keys 'media' and 'video' is allowed")
         if "media" in json_data:
             prompt_data["media"] = json_data["media"]
+        if "video" in json_data:
+            prompt_data["video"] = json_data["video"]
         text_param_list.append(prompt_data)
     return text_param_list


 def parse_image_json_data(json_data_list):
     image_param_list = []
-    for data in json_data_list:
-        image_param = {}
-        if 'prompt' in data:
-            if data['prompt'] != '':
-                image_param['prompt'] = data['prompt']
-            else:
-                raise RuntimeError('== prompt should not be empty string ==')
-        else:
-            raise RuntimeError('== key word "prompt" does not exist in prompt file ==')
-        if 'width' in data:
-            image_param['width'] = int(data['width'])
-        if 'height' in data:
-            image_param['height'] = int(data['height'])
-        if 'steps' in data:
-            image_param['steps'] = int(data['steps'])
-        if 'guidance_scale' in data:
-            image_param['guidance_scale'] = float(data['guidance_scale'])
-        if 'media' in data:
-            image_param['media'] = data['media']
-        if 'mask_image' in data:
-            image_param['mask_image'] = data['mask_image']
+    for json_data in json_data_list:
+        image_param = create_base_prompt(json_data)
+        if 'width' in json_data:
+            image_param['width'] = int(json_data['width'])
+        if 'height' in json_data:
+            image_param['height'] = int(json_data['height'])
+        if 'steps' in json_data:
+            image_param['steps'] = int(json_data['steps'])
+        if 'guidance_scale' in json_data:
+            image_param['guidance_scale'] = float(json_data['guidance_scale'])
+        if 'media' in json_data:
+            image_param['media'] = json_data['media']
+        if 'mask_image' in json_data:
+            image_param['mask_image'] = json_data['mask_image']
         image_param_list.append(image_param)
     return image_param_list


 def parse_speech_json_data(json_data_list):
     speech_param_list = []
     for json_data in json_data_list:
-        speech_param = {}
-        if 'media' in json_data:
-            if json_data['media'] != '':
-                speech_param['media'] = json_data['media']
-            else:
-                raise RuntimeError('== media path should not be empty string ==')
-        else:
-            raise RuntimeError('== key word "media" does not exist ==')
-        if 'language' in json_data:
-            speech_param['language'] = json_data['language']
-        if 'timestamp' in json_data:
-            speech_param['timestamp'] = json_data['timestamp']
+        speech_param = create_base_prompt(json_data, "media")
+        if "language" in json_data:
+            speech_param["language"] = json_data["language"]
+        if "timestamp" in json_data:
+            speech_param["timestamp"] = json_data["timestamp"]
         speech_param_list.append(speech_param)
     return speech_param_list
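
A quick sketch of how the refactored parsers behave (the file paths are made up): create_base_prompt validates the required key once, and parse_vlm_json_data then attaches at most one of "media" or "video":

    from llm_bench_utils.parse_json_data import parse_vlm_json_data

    entries = [
        {"prompt": "Describe the clip", "video": "./data/clip.mp4"},
        {"prompt": "Describe the image", "media": "./data/cat.png"},
    ]
    print(parse_vlm_json_data(entries))
    # [{'prompt': 'Describe the clip', 'video': './data/clip.mp4'},
    #  {'prompt': 'Describe the image', 'media': './data/cat.png'}]
    # An entry with both keys raises ValueError; a missing or empty
    # "prompt" raises RuntimeError from create_base_prompt.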

tools/llm_bench/llm_bench_utils/prompt_utils.py

Lines changed: 67 additions & 0 deletions
@@ -2,6 +2,12 @@
 # Copyright (C) 2023-2025 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0

+
+import os
+import cv2
+import numpy as np
+from PIL import Image
+import logging as log
 from .model_utils import get_param_from_file
 from .parse_json_data import parse_text_json_data

@@ -17,3 +23,64 @@ def get_text_prompt(args):
     else:
         text_list.append(output_data_list[0])
     return text_list
+
+
+def print_video_frames_number_and_convert_to_tensor(func):
+    def inner(video_path, decym_frames):
+        log.info(f"Input video file: {video_path}")
+        if decym_frames is not None:
+            log.info(f"Requested to reduce into {decym_frames} frames")
+        out_frames = func(video_path, decym_frames)
+        log.info(f"Final frames number: {len(out_frames)}")
+        return np.array(out_frames)
+    return inner
+
+
+@print_video_frames_number_and_convert_to_tensor
+def make_video_tensor(video_path, decym_frames=None):
+    supported_files = set([".mp4"])
+
+    assert os.path.exists(video_path), f"no input video file: {video_path}"
+    assert video_path.suffix.lower() in supported_files, "unsupported video file"
+    cap = cv2.VideoCapture(str(video_path))
+
+    output_frames = []
+    while True:
+        ret, frame = cap.read()
+        if not ret:
+            break
+        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+        pil_image = Image.fromarray(frame_rgb)
+
+        shape = np.array(pil_image).shape
+        dtype = np.array(pil_image).dtype
+        log.info(f"Video shape: {shape}")
+        log.info(f"Video dtype: {dtype}")
+        new_frame = np.zeros(shape, dtype)
+
+        width, height = pil_image.size
+        log.info(f"Video size: {width}x{height}")
+        for x in range(0, width):
+            for y in range(0, height):
+                new_frame[y, x] = frame_rgb[y, x]
+        output_frames.append(np.array(pil_image))
+
+    if decym_frames is None:
+        return output_frames
+    if int(decym_frames) == 0:
+        return output_frames
+
+    # decimation procedure
+    # decym_frames is the required frame count if positive,
+    # or the decimation factor if negative
+
+    decym_frames = int(decym_frames)
+    if decym_frames > 0:
+        if len(output_frames) <= decym_frames:
+            return output_frames
+        decym_factor = int(len(output_frames) / decym_frames)
+    else:
+        decym_factor = -decym_frames
+    if decym_factor >= 2:
+        return output_frames[::decym_factor]
+    return output_frames
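
The decimation rule at the end of make_video_tensor is easy to miss in the diff, so here is a standalone sketch of just that arithmetic: a positive decym_frames is a target frame count, a negative one is a keep-every-Nth step, and None or 0 keeps everything (when the frame count is not evenly divisible by the target, the result can slightly exceed the request):

    def decimate(frames, decym_frames):
        if not decym_frames:                       # None or 0: keep all frames
            return frames
        decym_frames = int(decym_frames)
        if decym_frames > 0:
            if len(frames) <= decym_frames:
                return frames
            factor = len(frames) // decym_frames   # target count -> step
        else:
            factor = -decym_frames                 # negative value is the step itself
        return frames[::factor] if factor >= 2 else frames

    frames = list(range(100))
    assert len(decimate(frames, 10)) == 10         # keep every 10th of 100 frames
    assert len(decimate(frames, -4)) == 25         # keep every 4th frame
    assert len(decimate(frames, None)) == 100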

tools/llm_bench/requirements.txt

Lines changed: 2 additions & 1 deletion
@@ -8,7 +8,7 @@ pillow
 torch
 transformers[sentencepiece]>=4.40.0
 diffusers>=0.22.0
-#optimum is in dependency list of optimum-intel
+#optimum is in dependency list of optimum-intel
 optimum-intel[nncf]>=1.25.0
 packaging
 psutil
@@ -21,3 +21,4 @@ scipy
 gguf_parser
 gguf>=0.10
 num2words
+opencv-python

tools/llm_bench/task/visual_language_generation.py

Lines changed: 48 additions & 25 deletions
@@ -17,27 +17,32 @@
 import llm_bench_utils.output_file
 import llm_bench_utils.gen_output_data as gen_output_data
 import llm_bench_utils.parse_json_data as parse_json_data
+import llm_bench_utils.prompt_utils as pu
 from pathlib import Path

-
 FW_UTILS = {'pt': llm_bench_utils.pt_utils, 'ov': llm_bench_utils.ov_utils}

 DEFAULT_OUTPUT_TOKEN_SIZE = 512


 def run_visual_language_generation_optimum(
-    inputs, num, model, processor, args, iter_data_list, md5_list, prompt_index, bench_hook, model_precision, proc_id, mem_consumption
-):
+        inputs, num, model, processor, args, iter_data_list, md5_list, prompt_index,
+        bench_hook, model_precision, proc_id, mem_consumption, required_frames=None):
     from optimum.intel.utils.import_utils import is_transformers_version
     set_seed(args['seed'])
     if args['batch_size'] != 1:
         log.warning("Only batch size 1 available for benchmarking")
         args["batch_size"] = 1
     images = []
     prompts = []
+    videos = []
     inputs = [inputs] if not isinstance(inputs, (list, tuple)) else inputs
     for input_data in inputs:
-        if input_data.get("media", None):
+        if input_data.get("video", None):
+            entry = Path(input_data["video"])
+            video_tensor = pu.make_video_tensor(entry, required_frames)
+            videos.append(video_tensor)
+        elif input_data.get("media", None):
             entry = Path(input_data["media"])
             if entry.is_dir():
                 for file in sorted(entry.iterdir()):
@@ -52,6 +57,8 @@ def run_visual_language_generation_optimum(
             llm_bench_utils.output_file.output_input_text(in_text, args, model_precision, prompt_index, bs_index, proc_id)
     tok_encode_start = time.perf_counter()
     input_data = model.preprocess_inputs(text=prompts[0], image=images[0] if images else None, **processor)
+    if videos:
+        input_data["videos"] = videos
     tok_encode_end = time.perf_counter()
     tok_encode_time = (tok_encode_end - tok_encode_start) * 1000
     # Remove `token_type_ids` from inputs
@@ -189,16 +196,21 @@ def load_image_genai(image_path):


 def run_visual_language_generation_genai(
-    inputs, num, model, processor, args, iter_data_list, md5_list, prompt_index, streamer, model_precision, proc_id, mem_consumption
-):
+        inputs, num, model, processor, args, iter_data_list, md5_list, prompt_index,
+        streamer, model_precision, proc_id, mem_consumption, required_frames=None):
     if args['batch_size'] != 1:
         log.warning("Only batch size 1 available for benchmarking")
         args["batch_size"] = 1
     images = []
     prompts = []
+    videos = []
     inputs = [inputs] if not isinstance(inputs, (list, tuple)) else inputs
     for input_data in inputs:
-        if input_data.get("media", None):
+        if input_data.get("video", None):
+            entry = Path(input_data["video"])
+            video_tensor = pu.make_video_tensor(entry, required_frames)
+            videos.append(video_tensor)
+        elif input_data.get("media", None):
             entry = Path(input_data["media"])
             if entry.is_dir():
                 for file in sorted(entry.iterdir()):
@@ -222,8 +234,10 @@ def run_visual_language_generation_genai(
     gen_config.do_sample = False
     gen_config.ignore_eos = True
     kwargs = {}
-    if len(images) >= 1:
+    if images:
         kwargs["images"] = images
+    if videos:
+        kwargs["videos"] = videos
     prefix = '[warm-up]' if num == 0 else '[{}]'.format(num)
     log.info(f'{prefix}[P{prompt_index}] Input image nums:{len(images)}')
     start = time.perf_counter()
@@ -304,8 +318,11 @@ def run_visual_language_generation_genai(
         metrics_print.print_generated(num, warm_up=(num == 0), generated=generated_text[0], prompt_idx=prompt_index)


-def run_visual_language_generation_benchmark(model_path, framework, device, args, num_iters, mem_consumption):
-    model, processor, pretrain_time, bench_hook, use_genai = FW_UTILS[framework].create_image_text_gen_model(model_path, device, mem_consumption, **args)
+def run_visual_language_generation_benchmark(
+        model_path, framework, device, args, num_iters,
+        mem_consumption, required_frames=None):
+    outs = FW_UTILS[framework].create_image_text_gen_model(model_path, device, mem_consumption, **args)
+    model, processor, pretrain_time, bench_hook, use_genai = outs
     model_precision = model_utils.get_model_precision(model_path.parts)
     iter_data_list = []
     md5_list = {num : {} for num in range(num_iters + 1)}
@@ -325,10 +342,10 @@ def run_visual_language_generation_benchmark(model_path, framework, device, args
     log.info(f"Numbeams: {args['num_beams']}, benchmarking iter nums(exclude warm-up): {num_iters}, "
              f'prompt nums: {len(image_text_list)}, prompt idx: {prompt_idx_list}')

-    if not use_genai:
-        gen_fn = run_visual_language_generation_optimum
-    else:
+    if use_genai:
         gen_fn = run_visual_language_generation_genai
+    else:
+        gen_fn = run_visual_language_generation_optimum

     proc_id = os.getpid()
     iter_timestamp = model_utils.init_timestamp(num_iters, image_text_list, prompt_idx_list)
@@ -337,41 +354,47 @@ def run_visual_language_generation_benchmark(model_path, framework, device, args
             for idx, input_text in enumerate(image_text_list):
                 p_idx = prompt_idx_list[idx]
                 if num == 0:
-                    metrics_print.print_unicode(f'[warm-up][P{p_idx}] Input text: {input_text}', max_output=metrics_print.MAX_INPUT_TXT_IN_LOG)
+                    metrics_print.print_unicode(f'[warm-up][P{p_idx}] Input text: {input_text}',
+                                                max_output=metrics_print.MAX_INPUT_TXT_IN_LOG)
                 iter_timestamp[num][p_idx]['start'] = datetime.datetime.now().isoformat()
                 gen_fn(
                     input_text, num, model, processor, args, iter_data_list, md5_list,
-                    p_idx, bench_hook, model_precision, proc_id, mem_consumption)
+                    p_idx, bench_hook, model_precision, proc_id, mem_consumption, required_frames)
                 iter_timestamp[num][p_idx]['end'] = datetime.datetime.now().isoformat()
-                prefix = '[warm-up]' if num == 0 else '[{}]'.format(num)
-                log.info(f"{prefix}[P{p_idx}] start: {iter_timestamp[num][p_idx]['start']}, end: {iter_timestamp[num][p_idx]['end']}")
+                prefix = f"[warm-up][P{p_idx}]" if num == 0 else f"[{num}][P{p_idx}]"
+                log.info(f"{prefix} start: {iter_timestamp[num][p_idx]['start']}, end: {iter_timestamp[num][p_idx]['end']}")
     else:
         for idx, input_text in enumerate(image_text_list):
             p_idx = prompt_idx_list[idx]
             for num in range(num_iters + 1):
                 if num == 0:
-                    metrics_print.print_unicode(f'[warm-up][P{p_idx}] Input text: {input_text}', max_output=metrics_print.MAX_INPUT_TXT_IN_LOG)
+                    metrics_print.print_unicode(f'[warm-up][P{p_idx}] Input text: {input_text}',
+                                                max_output=metrics_print.MAX_INPUT_TXT_IN_LOG)
                 iter_timestamp[num][p_idx]['start'] = datetime.datetime.now().isoformat()
                 gen_fn(
-                    input_text, num, model, processor, args, iter_data_list, md5_list,
-                    prompt_idx_list[idx], bench_hook, model_precision, proc_id, mem_consumption)
+                    input_text, num, model, processor, args, iter_data_list, md5_list, prompt_idx_list[idx],
+                    bench_hook, model_precision, proc_id, mem_consumption, required_frames)
                 iter_timestamp[num][p_idx]['end'] = datetime.datetime.now().isoformat()
-                prefix = '[warm-up]' if num == 0 else '[{}]'.format(num)
-                log.info(f"{prefix}[P{p_idx}] start: {iter_timestamp[num][p_idx]['start']}, end: {iter_timestamp[num][p_idx]['end']}")
+                prefix = f"[warm-up][P{p_idx}]" if num == 0 else f"[{num}][P{p_idx}]"
+                log.info(f"{prefix} start: {iter_timestamp[num][p_idx]['start']}, end: {iter_timestamp[num][p_idx]['end']}")

     metrics_print.print_average(iter_data_list, prompt_idx_list, args['batch_size'], True)
     return iter_data_list, pretrain_time, iter_timestamp


 def get_image_text_prompt(args):
     vlm_file_list = []
-    output_data_list, is_json_data = model_utils.get_param_from_file(args, ['media', "prompt"])
+    output_data_list, is_json_data = model_utils.get_param_from_file(args, ["media", "prompt"])
     if is_json_data:
         vlm_param_list = parse_json_data.parse_vlm_json_data(output_data_list)
         if len(vlm_param_list) > 0:
             for vlm_file in vlm_param_list:
-                if args['prompt_file'] is not None and len(args['prompt_file']) > 0:
-                    vlm_file['media'] = model_utils.resolve_media_file_path(vlm_file.get("media"), args['prompt_file'][0])
+                if args['prompt_file'] is not None and len(args['prompt_file']) > 0 and 'media' in vlm_file:
+                    if 'video' in vlm_file:
+                        raise ValueError('media and video cannot be specified in a single prompt file')
+                    vlm_file['media'] = model_utils.resolve_media_file_path(vlm_file.get('media'), args['prompt_file'][0])
+                elif args['prompt_file'] is not None and len(args['prompt_file']) > 0 and 'video' in vlm_file:
+                    vlm_file['video'] = model_utils.resolve_media_file_path(vlm_file.get('video'), args['prompt_file'][0])
                 vlm_file_list.append(vlm_file)
         else:
             vlm_file_list.append(output_data_list)
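
Putting the pieces together, a minimal sketch of the new video path, assuming a prompt-file entry of the form {"prompt": "...", "video": "..."} and a pipeline that accepts the videos keyword as assembled above (the clip path is hypothetical, and the generate call is only indicated):

    from pathlib import Path
    import llm_bench_utils.prompt_utils as pu

    required_frames = 16                      # mirrors -vf/--video_frames
    video = pu.make_video_tensor(Path("./data/clip.mp4"), required_frames)

    kwargs = {"videos": [video]}              # as built in run_visual_language_generation_genai
    # generated = model.generate(prompts[0], **kwargs, generation_config=gen_config)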
