@@ -243,7 +243,9 @@ def benchmark(
243243 print (get_system_info ())
244244
245245 if save_poses :
246- save_poses_to_files (video_path , save_dir , bodyparts , poses , timestamp = timestamp )
246+ individuals = dlc_live .read_config ()["metadata" ].get ("individuals" , [])
247+ n_individuals = len (individuals ) or 1
248+ save_poses_to_files (video_path , save_dir , n_individuals , bodyparts , poses , timestamp = timestamp )
247249
248250 return poses , times
249251
@@ -320,7 +322,7 @@ def draw_pose_and_write(
320322
321323 vwriter .write (image = frame )
322324
323- def save_poses_to_files (video_path , save_dir , bodyparts , poses , timestamp ):
325+ def save_poses_to_files (video_path , save_dir , n_individuals , bodyparts , poses , timestamp ):
324326 """
325327 Saves the detected keypoint poses from the video to CSV and HDF5 files.
326328
@@ -339,65 +341,48 @@ def save_poses_to_files(video_path, save_dir, bodyparts, poses, timestamp):
339341 -------
340342 None
341343 """
344+ import pandas as pd
342345
343- base_filename = os .path .splitext (os .path .basename (video_path ))[0 ]
344- csv_save_path = os .path .join (save_dir , f"{ base_filename } _poses_{ timestamp } .csv" )
345- h5_save_path = os .path .join (save_dir , f"{ base_filename } _poses_{ timestamp } .h5" )
346-
347- # Save to CSV
348- with open (csv_save_path , mode = "w" , newline = "" ) as file :
349- writer = csv .writer (file )
350- header = ["frame" ] + [
351- f"{ bp } _{ axis } " for bp in bodyparts for axis in ["x" , "y" , "confidence" ]
352- ]
353- writer .writerow (header )
354- for entry in poses :
355- frame_num = entry ["frame" ]
356- pose = entry ["pose" ]["poses" ][0 ][0 ]
357- row = [frame_num ] + [
358- item .item () if isinstance (item , torch .Tensor ) else item
359- for kp in pose
360- for item in kp
361- ]
362- writer .writerow (row )
363-
364- # Save to HDF5
365- with h5py .File (h5_save_path , "w" ) as hf :
366- hf .create_dataset (name = "frames" , data = [entry ["frame" ] for entry in poses ])
367- for i , bp in enumerate (bodyparts ):
368- hf .create_dataset (
369- name = f"{ bp } _x" ,
370- data = [
371- (
372- entry ["pose" ]["poses" ][0 ][0 ][i , 0 ].item ()
373- if isinstance (entry ["pose" ]["poses" ][0 ][0 ][i , 0 ], torch .Tensor )
374- else entry ["pose" ]["poses" ][0 ][0 ][i , 0 ]
375- )
376- for entry in poses
377- ],
378- )
379- hf .create_dataset (
380- name = f"{ bp } _y" ,
381- data = [
382- (
383- entry ["pose" ]["poses" ][0 ][0 ][i , 1 ].item ()
384- if isinstance (entry ["pose" ]["poses" ][0 ][0 ][i , 1 ], torch .Tensor )
385- else entry ["pose" ]["poses" ][0 ][0 ][i , 1 ]
386- )
387- for entry in poses
388- ],
389- )
390- hf .create_dataset (
391- name = f"{ bp } _confidence" ,
392- data = [
393- (
394- entry ["pose" ]["poses" ][0 ][0 ][i , 2 ].item ()
395- if isinstance (entry ["pose" ]["poses" ][0 ][0 ][i , 2 ], torch .Tensor )
396- else entry ["pose" ]["poses" ][0 ][0 ][i , 2 ]
397- )
398- for entry in poses
399- ],
400- )
346+ base_filename = Path (video_path ).stem
347+ save_dir = Path (save_dir )
348+ h5_save_path = save_dir / f"{ base_filename } _poses_{ timestamp } .h5"
349+ csv_save_path = save_dir / f"{ base_filename } _poses_{ timestamp } .csv"
350+
351+ poses_array = _create_poses_np_array (n_individuals , bodyparts , poses )
352+ flattened_poses = poses_array .reshape (poses_array .shape [0 ], - 1 )
353+
354+ if n_individuals == 1 :
355+ pdindex = pd .MultiIndex .from_product (
356+ [bodyparts , ["x" , "y" , "likelihood" ]], names = ["bodyparts" , "coords" ]
357+ )
358+ else :
359+ individuals = [f"individual_{ i } " for i in range (n_individuals )]
360+ pdindex = pd .MultiIndex .from_product (
361+ [individuals , bodyparts , ["x" , "y" , "likelihood" ]], names = ["individuals" , "bodyparts" , "coords" ]
362+ )
363+
364+ pose_df = pd .DataFrame (flattened_poses , columns = pdindex )
365+
366+ pose_df .to_hdf (h5_save_path , key = "df_with_missing" , mode = "w" )
367+ pose_df .to_csv (csv_save_path , index = False )
368+
369+ def _create_poses_np_array (n_individuals : int , bodyparts : list , poses : list ):
370+ # Create numpy array with poses:
371+ max_frame = max (p ["frame" ] for p in poses )
372+ pose_target_shape = (n_individuals , len (bodyparts ), 3 )
373+ poses_array = np .full ((max_frame + 1 , * pose_target_shape ), np .nan )
374+
375+ for item in poses :
376+ frame = item ["frame" ]
377+ pose = item ["pose" ]
378+ if pose .ndim == 2 :
379+ pose = pose [np .newaxis , :, :]
380+ padded_pose = np .full (pose_target_shape , np .nan )
381+ slices = tuple (slice (0 , min (pose .shape [i ], pose_target_shape [i ])) for i in range (3 ))
382+ padded_pose [slices ] = pose [slices ]
383+ poses_array [frame ] = padded_pose
384+
385+ return poses_array
401386
402387
403388import argparse
0 commit comments