Commit dfe5817

Benchmark pytorch: fix save_poses_to_files()
1 parent b493bef commit dfe5817

File tree

1 file changed: +47, -60 lines


dlclive/benchmark_pytorch.py

Lines changed: 47 additions & 60 deletions
@@ -243,7 +243,9 @@ def benchmark(
     print(get_system_info())
 
     if save_poses:
-        save_poses_to_files(video_path, save_dir, bodyparts, poses, timestamp=timestamp)
+        individuals = dlc_live.read_config()["metadata"].get("individuals", [])
+        n_individuals = len(individuals) or 1
+        save_poses_to_files(video_path, save_dir, n_individuals, bodyparts, poses, timestamp=timestamp)
 
     return poses, times
 
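The `len(individuals) or 1` fallback means a single-animal export, whose config lists no individuals, still yields a count of 1. A quick sketch, with a hypothetical metadata dict standing in for dlc_live.read_config()["metadata"]:

# Hypothetical metadata dicts; the real structure comes from dlc_live.read_config().
multi_animal = {"individuals": ["mouse_1", "mouse_2"]}
single_animal = {}  # no "individuals" key in a single-animal export

for metadata in (multi_animal, single_animal):
    individuals = metadata.get("individuals", [])
    n_individuals = len(individuals) or 1  # empty list -> falls back to 1
    print(n_individuals)  # prints 2, then 1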

@@ -320,7 +322,7 @@ def draw_pose_and_write(
 
     vwriter.write(image=frame)
 
-def save_poses_to_files(video_path, save_dir, bodyparts, poses, timestamp):
+def save_poses_to_files(video_path, save_dir, n_individuals, bodyparts, poses, timestamp):
     """
     Saves the detected keypoint poses from the video to CSV and HDF5 files.
 
@@ -330,6 +332,8 @@ def save_poses_to_files(video_path, save_dir, bodyparts, poses, timestamp):
         Path to the analyzed video file.
     save_dir : str
         Directory where the pose data files will be saved.
+    n_individuals : int
+        Number of individuals.
     bodyparts : list of str
         List of body part names corresponding to the keypoints.
     poses : list of dict
@@ -339,65 +343,48 @@ def save_poses_to_files(video_path, save_dir, bodyparts, poses, timestamp):
     -------
     None
     """
+    import pandas as pd
 
-    base_filename = os.path.splitext(os.path.basename(video_path))[0]
-    csv_save_path = os.path.join(save_dir, f"{base_filename}_poses_{timestamp}.csv")
-    h5_save_path = os.path.join(save_dir, f"{base_filename}_poses_{timestamp}.h5")
-
-    # Save to CSV
-    with open(csv_save_path, mode="w", newline="") as file:
-        writer = csv.writer(file)
-        header = ["frame"] + [
-            f"{bp}_{axis}" for bp in bodyparts for axis in ["x", "y", "confidence"]
-        ]
-        writer.writerow(header)
-        for entry in poses:
-            frame_num = entry["frame"]
-            pose = entry["pose"]["poses"][0][0]
-            row = [frame_num] + [
-                item.item() if isinstance(item, torch.Tensor) else item
-                for kp in pose
-                for item in kp
-            ]
-            writer.writerow(row)
-
-    # Save to HDF5
-    with h5py.File(h5_save_path, "w") as hf:
-        hf.create_dataset(name="frames", data=[entry["frame"] for entry in poses])
-        for i, bp in enumerate(bodyparts):
-            hf.create_dataset(
-                name=f"{bp}_x",
-                data=[
-                    (
-                        entry["pose"]["poses"][0][0][i, 0].item()
-                        if isinstance(entry["pose"]["poses"][0][0][i, 0], torch.Tensor)
-                        else entry["pose"]["poses"][0][0][i, 0]
-                    )
-                    for entry in poses
-                ],
-            )
-            hf.create_dataset(
-                name=f"{bp}_y",
-                data=[
-                    (
-                        entry["pose"]["poses"][0][0][i, 1].item()
-                        if isinstance(entry["pose"]["poses"][0][0][i, 1], torch.Tensor)
-                        else entry["pose"]["poses"][0][0][i, 1]
-                    )
-                    for entry in poses
-                ],
-            )
-            hf.create_dataset(
-                name=f"{bp}_confidence",
-                data=[
-                    (
-                        entry["pose"]["poses"][0][0][i, 2].item()
-                        if isinstance(entry["pose"]["poses"][0][0][i, 2], torch.Tensor)
-                        else entry["pose"]["poses"][0][0][i, 2]
-                    )
-                    for entry in poses
-                ],
-            )
+    base_filename = Path(video_path).stem
+    save_dir = Path(save_dir)
+    h5_save_path = save_dir / f"{base_filename}_poses_{timestamp}.h5"
+    csv_save_path = save_dir / f"{base_filename}_poses_{timestamp}.csv"
+
+    poses_array = _create_poses_np_array(n_individuals, bodyparts, poses)
+    flattened_poses = poses_array.reshape(poses_array.shape[0], -1)
+
+    if n_individuals == 1:
+        pdindex = pd.MultiIndex.from_product(
+            [bodyparts, ["x", "y", "likelihood"]], names=["bodyparts", "coords"]
+        )
+    else:
+        individuals = [f"individual_{i}" for i in range(n_individuals)]
+        pdindex = pd.MultiIndex.from_product(
+            [individuals, bodyparts, ["x", "y", "likelihood"]], names=["individuals", "bodyparts", "coords"]
+        )
+
+    pose_df = pd.DataFrame(flattened_poses, columns=pdindex)
+
+    pose_df.to_hdf(h5_save_path, key="df_with_missing", mode="w")
+    pose_df.to_csv(csv_save_path, index=False)
+
+def _create_poses_np_array(n_individuals: int, bodyparts: list, poses: list):
+    # Create numpy array with poses:
+    max_frame = max(p["frame"] for p in poses)
+    pose_target_shape = (n_individuals, len(bodyparts), 3)
+    poses_array = np.full((max_frame + 1, *pose_target_shape), np.nan)
+
+    for item in poses:
+        frame = item["frame"]
+        pose = item["pose"]
+        if pose.ndim == 2:
+            pose = pose[np.newaxis, :, :]
+        padded_pose = np.full(pose_target_shape, np.nan)
+        slices = tuple(slice(0, min(pose.shape[i], pose_target_shape[i])) for i in range(3))
+        padded_pose[slices] = pose[slices]
+        poses_array[frame] = padded_pose
+
+    return poses_array
 
 
 import argparse
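
For illustration, a minimal, self-contained sketch of the output format the rewritten save_poses_to_files() produces, using a toy single-animal model with two bodyparts. The file names and values here are made up, and the NaN-padding loop mirrors _create_poses_np_array rather than importing it:

import numpy as np
import pandas as pd

bodyparts = ["nose", "tail"]
n_individuals = 1

# Each entry mirrors what the new code expects: item["pose"] is an
# (n_individuals, n_bodyparts, 3) array of x, y, likelihood.
poses = [
    {"frame": 0, "pose": np.array([[[10.0, 20.0, 0.9], [30.0, 40.0, 0.8]]])},
    {"frame": 2, "pose": np.array([[[11.0, 21.0, 0.7], [31.0, 41.0, 0.6]]])},
]

# NaN-padded array over all frames up to the largest frame index,
# as in _create_poses_np_array; frame 1 stays NaN because it is missing.
max_frame = max(p["frame"] for p in poses)
poses_array = np.full((max_frame + 1, n_individuals, len(bodyparts), 3), np.nan)
for item in poses:
    poses_array[item["frame"]] = item["pose"]

flattened = poses_array.reshape(poses_array.shape[0], -1)
columns = pd.MultiIndex.from_product(
    [bodyparts, ["x", "y", "likelihood"]], names=["bodyparts", "coords"]
)
pose_df = pd.DataFrame(flattened, columns=columns)

# Hypothetical file names; the real paths are derived from the video name and timestamp.
pose_df.to_hdf("example_poses.h5", key="df_with_missing", mode="w")
pose_df.to_csv("example_poses.csv", index=False)

# Reading the HDF5 back returns the same MultiIndex-columned frame.
df_back = pd.read_hdf("example_poses.h5", key="df_with_missing")
print(df_back)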
