diff --git a/internnav/agent/internvla_n1_agent.py b/internnav/agent/internvla_n1_agent.py index 6e97e02..8e696af 100644 --- a/internnav/agent/internvla_n1_agent.py +++ b/internnav/agent/internvla_n1_agent.py @@ -1,22 +1,21 @@ -import os import copy -from PIL import Image -import numpy as np -import imageio -import base64 -import pickle +import os import threading -import cv2 -from gym import spaces import time + +import cv2 +import imageio +import numpy as np import torch +from gym import spaces +from PIL import Image -from internnav.configs.agent import StepRequest, AgentCfg from internnav.agent.base import Agent +from internnav.configs.agent import AgentCfg from internnav.configs.model.base_encoders import ModelCfg from internnav.model import get_config, get_policy from internnav.model.utils.misc import set_random_seed -from internnav.model.utils.vln_utils import S2Input, S2Output, S1Input, S1Output +from internnav.model.utils.vln_utils import S1Input, S1Output, S2Input, S2Output @Agent.register('internvla_n1') @@ -27,58 +26,55 @@ class InternVLAN1Agent(Agent): shape=(256, 256, 1), dtype=np.float32, ) - + def __init__(self, config: AgentCfg): super().__init__(config) set_random_seed(0) vln_sensor_config = self.config.model_settings self._model_settings = ModelCfg(**vln_sensor_config) - env_num = getattr(self._model_settings, 'env_num', 1) - sim_num = getattr(self._model_settings, 'sim_num', 1) self.device = torch.device(self._model_settings.device) self.mode = getattr(self._model_settings, 'infer_mode', 'sync') self.sys2_max_forward_step = getattr(self._model_settings, 'sys2_max_forward_step', 8) - + policy = get_policy(self._model_settings.policy_name) policy_config = get_config(self._model_settings.policy_name) model_config = {'model': self._model_settings.model_dump()} self.policy = policy(config=policy_config(model_cfg=model_config)) self.policy.eval() - + self.camera_intrinsic = self.get_intrinsic_matrix( self._model_settings.width, self._model_settings.height, self._model_settings.hfov ) - + self.episode_step = 0 self.episode_idx = 0 self.look_down = False - - - ### for async dual sys + + # for async dual sys self.pixel_goal_rgb = None self.pixel_goal_depth = None self.dual_forward_step = 0 self.sys1_infer_times = 0 - + self.sys1_depth_threshold = 5.0 self.sys1_forward_step = 4 - + self.s1_input = S1Input() self.s2_input = S2Input() self.s2_output = S2Output() self.s1_output = S1Output() - + # Thread management self.s2_thread = None - + # Thread locks self.s2_input_lock = threading.Lock() self.s2_output_lock = threading.Lock() self.s2_agent_lock = threading.Lock() - + # Start S2 thread self._start_s2_thread() - + # vis debug self.vis_debug = vln_sensor_config['vis_debug'] if self.vis_debug: @@ -87,7 +83,7 @@ def __init__(self, config: AgentCfg): self.fps_writer = imageio.get_writer(f"{self.debug_path}/fps_{self.episode_idx}.mp4", fps=5) self.fps_writer2 = imageio.get_writer(f"{self.debug_path}/fps_{self.episode_idx}_dp.mp4", fps=5) self.output_pixel = None - + def reset(self, reset_index=None): '''reset_index: [0]''' if reset_index is not None: @@ -97,7 +93,7 @@ def reset(self, reset_index=None): self.fps_writer2.close() else: self.episode_idx = -1 - + self.episode_step = 0 self.s1_input = S1Input() with self.s2_input_lock: @@ -105,21 +101,21 @@ def reset(self, reset_index=None): with self.s2_output_lock: self.s2_output = S2Output() self.s1_output = S1Output() - - ### for async dual sys + + # for async dual sys self.pixel_goal_rgb = None self.pixel_goal_depth = None - 
self.dual_forward_step = 0 + self.dual_forward_step = 0 self.sys1_infer_times = 0 - + # Reset s2 agent with self.s2_agent_lock: self.policy.reset() - + if self.vis_debug: self.fps_writer = imageio.get_writer(f"{self.debug_path}/fps_{self.episode_idx}.mp4", fps=5) self.fps_writer2 = imageio.get_writer(f"{self.debug_path}/fps_{self.episode_idx}_dp.mp4", fps=5) - + def get_intrinsic_matrix(self, width, height, hfov) -> np.ndarray: width = width height = height @@ -129,14 +125,11 @@ def get_intrinsic_matrix(self, width, height, hfov) -> np.ndarray: cx = (width - 1.0) / 2.0 cy = (height - 1.0) / 2.0 - intrinsic_matrix = np.array([ - [fx, 0.0, cx, 0.0], - [ 0.0, fy, cy, 0.0], - [ 0.0, 0.0, 1.0, 0.0], - [ 0.0, 0.0, 0.0, 1.0] - ]) + intrinsic_matrix = np.array( + [[fx, 0.0, cx, 0.0], [0.0, fy, cy, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]] + ) return intrinsic_matrix - + def _start_s2_thread(self): def s2_thread_func(): while True: @@ -149,7 +142,7 @@ def s2_thread_func(): else: time.sleep(0.5) # Sleep briefly if inference is not needed continue - + # # Check if currently inferring # if self.mode == "sync": # if not self.s2_output.is_infering: @@ -158,12 +151,19 @@ def s2_thread_func(): # else: # time.sleep(0.5) # Sleep briefly if already inferring # continue - + # Execute inference success = True try: with self.s2_agent_lock: - current_s2_output = self.policy.s2_step(self.s2_input.rgb, self.s2_input.depth, self.s2_input.pose, self.s2_input.instruction, self.camera_intrinsic, self.s2_input.look_down) + current_s2_output = self.policy.s2_step( + self.s2_input.rgb, + self.s2_input.depth, + self.s2_input.pose, + self.s2_input.instruction, + self.camera_intrinsic, + self.s2_input.look_down, + ) except Exception as e: print(f"s2 infer error: {e}") self.s2_output.is_infering = False @@ -171,22 +171,29 @@ def s2_thread_func(): success = False if not success: try: - current_s2_output = self.policy.s2_step(self.s2_input.rgb, self.s2_input.depth, self.s2_input.pose, self.s2_input.instruction, self.camera_intrinsic, False) + current_s2_output = self.policy.s2_step( + self.s2_input.rgb, + self.s2_input.depth, + self.s2_input.pose, + self.s2_input.instruction, + self.camera_intrinsic, + False, + ) except Exception as e: print(f"s2 infer error: {e}") self.s2_output.is_infering = False self.policy.reset() self.s2_output.output_pixel = None - self.s2_output.output_action = [0] # finish the inference + self.s2_output.output_action = [0] # finish the inference self.s2_output.output_latent = None continue - - print(f"s2 infer finish!!") + + print("s2 infer finish!!") # Update output state with self.s2_output_lock: - print(f"get s2 output lock") + print("get s2 output lock") # S2 output - + self.s2_output.output_pixel = current_s2_output.output_pixel self.s2_output.output_action = current_s2_output.output_action self.s2_output.output_latent = current_s2_output.output_latent @@ -199,20 +206,20 @@ def s2_thread_func(): self.s2_thread = threading.Thread(target=s2_thread_func) self.s2_thread.daemon = True self.s2_thread.start() - + def should_infer_s2(self, mode="sync"): """Function: Enables the sys2 inference thread depending on the mode. mode: just support 2 modes: "sync" and "partial_async". "sync": Synchronous mode (navdp_version >= 0.0), Sys1 and Sys2 execute in a sequential inference chain. 
- "partial_async": Asynchronous mode (navdp_version > 0.0, e.g., 0.1), + "partial_async": Asynchronous mode (navdp_version > 0.0, e.g., 0.1), Sys2 performs a single inference, while Sys1 performs multiple inference cycles. """ if self.episode_step == 0: return True - + if self.s2_output.is_infering: return False - + # 1. Synchronous mode: infer S2 every frame to provide to S1 for execution if mode == "sync": if self.s2_output.output_action is None: @@ -223,23 +230,27 @@ def should_infer_s2(self, mode="sync"): if mode == "partial_async": if self.dual_forward_step >= self.sys2_max_forward_step: return True - if self.s2_output.output_action is None and self.s2_output.output_pixel is None and self.s2_output.output_latent is None: + if ( + self.s2_output.output_action is None + and self.s2_output.output_pixel is None + and self.s2_output.output_latent is None + ): # This normally only occurs when output is discrete action and discrete action has been fully executed return True return False raise ValueError("Invalid mode: {}".format(mode)) - + def step(self, obs): - mode = self.mode #'sync', 'partial_async' - - obs = obs[0] # do not support batch_env currently? + mode = self.mode # 'sync', 'partial_async' + + obs = obs[0] # do not support batch_env currently? rgb = obs['rgb'] depth = obs['depth'] instruction = obs['instruction'] pose = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]) - + # S2 inference is done in a separate thread - if self.should_infer_s2(mode) or self.look_down: # The look down frame must be inferred + if self.should_infer_s2(mode) or self.look_down: # The look down frame must be inferred print(f"======== Infer S2 at step {self.episode_step}========") with self.s2_input_lock: self.s2_input.idx = self.episode_step @@ -249,8 +260,8 @@ def step(self, obs): self.s2_input.instruction = instruction self.s2_input.should_infer = True self.s2_input.look_down = self.look_down - self.s2_output.is_infering = True #for async - + self.s2_output.is_infering = True # for async + self.dual_forward_step = 0 else: # Even if this frame doesn't do s2 inference, rgb needs to be provided to ensure history is correct @@ -258,17 +269,17 @@ def step(self, obs): # S1 inference is done in the main thread while self.s2_output.is_infering: time.sleep(0.5) - + while not self.s2_output.validate(): time.sleep(0.2) - + output = {} # Simple branch: # 1. If S2 output is full discrete actions, don't execute S1 and return directly print('===============', self.s2_output.output_action, '=================') if self.s2_output.output_action is not None: output['action'] = [self.s2_output.output_action[0]] - + with self.s2_output_lock: self.s2_output.output_action = self.s2_output.output_action[1:] if self.s2_output.output_action == []: @@ -286,34 +297,53 @@ def step(self, obs): self.look_down = False if self.sys1_infer_times > 0: self.dual_forward_step += 1 - + # print('Output action:', output, self.dual_forward_step) - + else: self.look_down = False # 2. 
If output is in latent form, execute latent S1 if self.s2_output.output_latent is not None: self.output_pixel = copy.deepcopy(self.s2_output.output_pixel) print(self.output_pixel) - + if mode != 'sync': - processed_pixel_rgb = np.array(Image.fromarray(self.s2_output.rgb_memory).resize((224, 224))) / 255.0 - processed_pixel_depth = np.array(Image.fromarray(self.s2_output.depth_memory[:,:,0]).resize((224, 224))) * 10.0 + processed_pixel_rgb = ( + np.array(Image.fromarray(self.s2_output.rgb_memory).resize((224, 224))) / 255.0 + ) + processed_pixel_depth = ( + np.array(Image.fromarray(self.s2_output.depth_memory[:, :, 0]).resize((224, 224))) * 10.0 + ) processed_pixel_depth[processed_pixel_depth > self.sys1_depth_threshold] = self.sys1_depth_threshold - + processed_rgb = np.array(Image.fromarray(rgb).resize((224, 224))) / 255.0 - processed_depth = np.array(Image.fromarray(depth[:,:,0]).resize((224, 224))) * 10.0 # should be 0-10m + processed_depth = ( + np.array(Image.fromarray(depth[:, :, 0]).resize((224, 224))) * 10.0 + ) # should be 0-10m processed_depth[processed_depth > self.sys1_depth_threshold] = self.sys1_depth_threshold - - rgbs = torch.stack([torch.from_numpy(processed_pixel_rgb), torch.from_numpy(processed_rgb)]).unsqueeze(0).to(self.device) #[1, 2, 224, 224, 3] - depths = torch.stack([torch.from_numpy(processed_pixel_depth), torch.from_numpy(processed_depth)]).unsqueeze(0).unsqueeze(-1).to(self.device)#[1, 2, 224, 224, 1] - self.s1_output = self.policy.s1_step_latent(rgbs, depths, self.s2_output.output_latent, use_async=True) + + rgbs = ( + torch.stack([torch.from_numpy(processed_pixel_rgb), torch.from_numpy(processed_rgb)]) + .unsqueeze(0) + .to(self.device) + ) # [1, 2, 224, 224, 3] + depths = ( + torch.stack([torch.from_numpy(processed_pixel_depth), torch.from_numpy(processed_depth)]) + .unsqueeze(0) + .unsqueeze(-1) + .to(self.device) + ) # [1, 2, 224, 224, 1] + self.s1_output = self.policy.s1_step_latent( + rgbs, depths, self.s2_output.output_latent, use_async=True + ) else: - self.s1_output = self.policy.s1_step_latent(rgb, depth * 10000.0, self.s2_output.output_latent, use_async=False) - + self.s1_output = self.policy.s1_step_latent( + rgb, depth * 10000.0, self.s2_output.output_latent, use_async=False + ) + else: assert False, f"S2 output should be either action or latent, but got neither! 
{self.s2_output}" - + if self.s1_output.idx == []: output['action'] = [-1] else: @@ -325,9 +355,8 @@ def step(self, obs): self.s2_output.output_action = None else: self.s2_output.output_action = None - - - self.s2_output.output_pixel = None #TODO: now just for visulization + + self.s2_output.output_pixel = None # TODO: now just for visulization if mode == 'sync': self.s2_output.output_latent = None else: @@ -339,31 +368,41 @@ def step(self, obs): self.sys1_infer_times += 1 self.dual_forward_step += 1 - + if self.dual_forward_step > self.sys2_max_forward_step: print("!!!!!!!!!!!!") print("ERR: self.dual_forward_step ", self.dual_forward_step, " > ", self.sys2_max_forward_step) print("!!!!!!!!!!!!") - + print('Output discretized traj:', output['action'], self.dual_forward_step) - - # Visualization + + # Visualization if self.vis_debug: vis = rgb.copy() if 'action' in output: vis = cv2.putText(vis, str(output['action'][0]), (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2) if self.output_pixel is not None: pixel = self.output_pixel - vis = cv2.putText(vis, f"{pixel[1]}, {pixel[0]} ({self.s2_output.idx})", (50, 100), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2) - + vis = cv2.putText( + vis, + f"{pixel[1]}, {pixel[0]} ({self.s2_output.idx})", + (50, 100), + cv2.FONT_HERSHEY_SIMPLEX, + 1, + (0, 255, 0), + 2, + ) + cv2.circle(vis, (pixel[1], pixel[0]), 5, (0, 255, 0), -1) self.output_pixel = None self.fps_writer.append_data(vis) - + if self.s1_output.vis_image is not None: - Image.fromarray(self.s1_output.vis_image).save(os.path.join("./vis_debug_pix/", f"ttttt_{self.episode_step}.png")) + Image.fromarray(self.s1_output.vis_image).save( + os.path.join("./vis_debug_pix/", f"ttttt_{self.episode_step}.png") + ) self.fps_writer2.append_data(self.s1_output.vis_image) - + self.episode_step += 1 if 'action' in output: return [{'action': output['action'], 'ideal_flag': True}] diff --git a/internnav/configs/agent/__init__.py b/internnav/configs/agent/__init__.py index fcafb67..ee00ef0 100644 --- a/internnav/configs/agent/__init__.py +++ b/internnav/configs/agent/__init__.py @@ -5,9 +5,9 @@ class AgentCfg(BaseModel): server_host: str = 'localhost' - server_port: int = 5000 + server_port: int = 8087 model_name: str - ckpt_path: str + ckpt_path: str = None model_settings: Dict[str, Any] diff --git a/internnav/configs/evaluator/__init__.py b/internnav/configs/evaluator/__init__.py index ab770c5..27e63a3 100644 --- a/internnav/configs/evaluator/__init__.py +++ b/internnav/configs/evaluator/__init__.py @@ -59,9 +59,9 @@ class EvalCfg(BaseModel): eval_type: Optional[str] = None eval_settings: Optional[Dict[str, Any]] = {} agent: Optional[AgentCfg] = None - env: EnvCfg - task: TaskCfg - dataset: EvalDatasetCfg + env: EnvCfg = None + task: TaskCfg = None + dataset: EvalDatasetCfg = None __all__ = [ diff --git a/internnav/env/internutopia_env.py b/internnav/env/internutopia_env.py index ad7fcf1..4faee7c 100644 --- a/internnav/env/internutopia_env.py +++ b/internnav/env/internutopia_env.py @@ -1,7 +1,13 @@ +import os +import sys from typing import Any, Dict, List from internnav.configs.evaluator import EnvCfg, TaskCfg from internnav.env import base +from internnav.env.utils.episode_loader import ( + ResumablePathKeyEpisodeloader, + generate_vln_episode, +) @base.Env.register('internutopia') @@ -22,6 +28,23 @@ def __init__(self, env_config: EnvCfg, task_config: TaskCfg): super().__init__(env_config, task_config) env_settings = self.env_config.env_settings task_settings = self.task_config.task_settings + + # 
generate episodes + self.episode_loader = ResumablePathKeyEpisodeloader( + env_settings['dataset'].dataset_type, + **env_settings['dataset'].dataset_settings, + rank=env_settings['rank'], + world_size=env_settings['world_size'] + ) + self.episodes = generate_vln_episode(self.episode_loader, task_config) + if len(self.episodes) == 0: + print("No episodes found for the given configuration.") + sys.exit(0) + task_settings.update({'episodes': self.episodes}) + + # set visible device for isaac sim + os.environ["CUDA_VISIBLE_DEVICES"] = str(env_settings.get('local_rank', 0)) + config = Config( simulator=SimConfig(**env_settings), env_num=task_settings['env_num'], diff --git a/internnav/env/utils/episode_loader/__init__.py b/internnav/env/utils/episode_loader/__init__.py new file mode 100644 index 0000000..b3d58ad --- /dev/null +++ b/internnav/env/utils/episode_loader/__init__.py @@ -0,0 +1,2 @@ +from .generate_episode import generate_vln_episode +from .resumable import ResumablePathKeyEpisodeloader diff --git a/internnav/projects/dataloader/base.py b/internnav/env/utils/episode_loader/base.py similarity index 67% rename from internnav/projects/dataloader/base.py rename to internnav/env/utils/episode_loader/base.py index d00a67b..ab18b61 100644 --- a/internnav/projects/dataloader/base.py +++ b/internnav/env/utils/episode_loader/base.py @@ -1,9 +1,7 @@ -from internnav.evaluator.utils.common import load_data +from .dataset_utils import load_data, revise_one_data, skip_list -from .data_reviser import revise_one_data, skip_list - -class BasePathKeyDataloader: +class BasePathKeyEpisodeloader: def __init__( self, dataset_type, @@ -13,7 +11,15 @@ def __init__( filter_same_trajectory, revise_data=True, filter_stairs=True, + rank=0, + world_size=1, ): + # currently supported dataset types in InternUtopia + # only kujiale has a special scene path + # other types should be treated the same as mp3d when loading + allowed = ('R2RVLN', 'mp3d', 'kujiale', 'grscene') + assert dataset_type in allowed, f"Unsupported dataset type: {dataset_type}. 
Allowed: {allowed}" + self.path_key_data = {} self.path_key_scan = {} self.path_key_split = {} @@ -25,14 +31,19 @@ def __init__( filter_same_trajectory=filter_same_trajectory, filter_stairs=filter_stairs, dataset_type=dataset_type, + rank=rank, + world_size=world_size, ) for scan, path_list in load_data_map.items(): for path in path_list: trajectory_id = path['trajectory_id'] - if revise_data: + + # tiny revision for R2R dataset in MP3D to fit vlnpe task + if dataset_type == 'mp3d' and revise_data: if trajectory_id in skip_list: continue path = revise_one_data(path) + episode_id = path['episode_id'] path_key = f'{trajectory_id}_{episode_id}' path['start_position'] += robot_offset diff --git a/internnav/env/utils/episode_loader/dataset_utils.py b/internnav/env/utils/episode_loader/dataset_utils.py new file mode 100644 index 0000000..f740922 --- /dev/null +++ b/internnav/env/utils/episode_loader/dataset_utils.py @@ -0,0 +1,635 @@ +import copy +import gzip +import json +import os +from collections import defaultdict + +import numpy as np + +from internnav.utils.common_log_util import common_logger as log + +fall_path_z_0_3 = [ + 70, + 121, + 146, + 156, + 172, + 326, + 349, + 372, + 394, + 415, + 434, + 469, + 531, + 550, + 580, + 626, + 674, + 700, + 768, + 808, + 823, + 835, + 854, + 859, + 958, + 1009, + 1058, + 1065, + 1093, + 1105, + 1142, + 1205, + 1238, + 1245, + 1263, + 1290, + 1295, + 1353, + 1400, + 1403, + 1455, + 1470, + 1530, + 1644, + 1645, + 1650, + 1734, + 1771, + 1848, + 1876, + 1880, + 1893, + 1925, + 1928, + 1957, + 1967, + 1995, + 2051, + 2061, + 2100, + 2101, + 2102, + 2156, + 2173, + 2186, + 2252, + 2253, + 2296, + 2335, + 2360, + 2399, + 2441, + 2485, + 2502, + 2508, + 2530, + 2591, + 2609, + 2622, + 2632, + 2651, + 2676, + 2744, + 2752, + 2809, + 2871, + 2911, + 2951, + 2967, + 2968, + 2981, + 2991, + 3023, + 3031, + 3032, + 3078, + 3093, + 3115, + 3145, + 3156, + 3160, + 3183, + 3194, + 3291, + 3304, + 3351, + 3528, + 3534, + 3576, + 3596, + 3605, + 3629, + 3656, + 3665, + 3689, + 3733, + 3749, + 3789, + 3833, + 3838, + 3859, + 3863, + 3868, + 3890, + 3978, + 3984, + 3993, + 4005, + 4022, + 4112, + 4122, + 4136, + 4214, + 4257, + 4264, + 4281, + 4311, + 4318, + 4356, + 4407, + 4460, + 4467, + 4533, + 4536, + 4551, + 4586, + 4656, + 4694, + 4698, + 4725, + 4800, + 4805, + 4807, + 4848, + 4867, + 4927, + 4949, + 5103, + 5170, + 5176, + 5228, + 5325, + 5327, + 5427, + 5443, + 5462, + 5529, + 5552, + 5625, + 5660, + 5690, + 5703, + 5753, + 5757, + 5817, + 5900, + 5928, + 5948, + 5955, + 6004, + 6109, + 6113, + 6120, + 6141, + 6181, + 6206, + 6221, + 6260, + 6283, + 6404, + 6422, + 6529, + 6608, + 6631, + 6660, + 6713, + 6731, + 6736, + 6749, + 6786, + 6800, + 6913, + 6916, + 6938, + 6971, + 6993, + 7021, + 7052, + 7145, + 7180, + 7202, + 7264, + 3477, + 5197, + 6372, + 4175, + 5929, + 7029, + 1924, + 2376, + 4877, + 6463, + 765, + 4415, + 5133, + 59, + 246, + 592, + 604, + 952, + 1185, + 1362, + 2680, + 3727, + 839, + 1444, + 274, + 3265, + 3592, + 4514, + 5847, + 6005, + 6599, + 2461, + 3703, + 219, + 1731, + 1822, + 6055, + 6142, + 7289, + 5280, + 41, + 1982, + 2108, + 2247, + 2554, + 3853, + 4818, + 6768, + 6794, + 7003, + 7033, + 2733, + 4860, + 606, + 1200, + 1083, + 6039, + 651, + 797, + 1014, + 4006, + 5454, + 6826, + 6899, + 6933, + 6373, + 1415, + 1418, + 2457, + 4691, + 6342, + 621, + 602, + 946, + 5431, + 6163, + 6208, + 890, + 1668, + 2031, + 4161, + 4826, + 6183, + 1592, + 3645, + 4376, + 109, + 369, + 743, + 1432, + 2147, + 2190, + 3946, + 5720, + 6680, + 
2994, + 3039, + 3781, + 4754, + 4920, + 6774, + 6942, + 2950, + 5624, + 3960, + 4890, + 4994, + 6036, + 2306, +] + +skip_list = [] + +fall_path_custom = { + 6558: [-1, 0, 0], + 454: [0.42, 0.9, 0], + 490: [0.97, 0.25, 0], + 910: [-0.4, 0, 0], + 1253: [-0.4, 0, 0], + 1834: [0, -0.5, 0.3], + 2004: [0.5, 0.5, 0], + 2249: [1, -1, 0], + 2382: [1, -0.5, 0], + 2468: [0.2, 0, 0], + 2498: [-0.2, -0.5, 0], + 2523: [1, 0, 0], + 2529: [1, 0.3, 0], + 2618: [-0.5, 0.2, 0.3], + 2688: [0, -1, 0], + 2768: [-0.86, 0.52, 0], + 3084: [0.88, -0.47, 0], + 3136: [1.0, 0, 0], + 3165: [0, 0, 0.8], + 3231: [0, -0.5, 0.3], + 3277: [0, 1, 0.3], + 3414: [0.5, 0, 0.3], + 3464: [0.7, -1, 0], + 3468: [-0.5, 0, 0], + 3686: [0.2, 0.2, 0], + 4073: [-0.24, 0.5, 0], + 4243: [0.2, 0, 0], + 4305: [0, -0.2, 0], + 4564: [-0.5, 0, 0], + 5252: [0.2, 0, 0.3], + 5328: [0, 0.5, 0], + 5401: [-1, -0.2, 0.0], + 5461: [-1.0, 0, 0.3], + 5560: [0, -0.5, 0.0], + 5609: [0.5, 0, 0.3], + 5930: [0.5, 0, 0], + 6262: [-0.5, 0, 0], + 6640: [0, -0.5, 0], + 6840: [0, -0.5, 0], + 6914: [0, -0.5, 0], + 7108: [0.5, 0, 0], + 7229: [0, -0.5, 0], + 7246: [0, 0.2, 0], + 7273: [0.5, 0, 0], + 338: [1, 1.2, 0.3], + 435: [0, 1, 0], + 2965: [0, 1, 0], + 3258: [0, 1, 0], + 1483: [0.5, 0, 0.3], + 5256: [0.8, 0, 0], + 1234: [0.2, -0.2, 0], + 1954: [0.2, -0.2, 0], + 2322: [0.2, 1, 0], + 6390: [0.2, 1, 0], + 6672: [0, 0.5, 0], + 5372: [0.5, 0, 0], + 2357: [0.3, -0.3, 0], + 95: [0.2, -0.5, 0], + 2778: [0.4, -0.5, 0], + 7281: [0.2, -0.5, 0], + 332: [-0.3, 0, 0], + 648: [-0.3, 0, 0], + 2716: [-0.2, 0, 0], + 2896: [0.2, 0.2, 0], + 3028: [0.2, 0.2, 0], + 3754: [0, 0.2, 0], + 4463: [-0.1, 0, 0], + 4615: [-0.1, 0, 0], + 5773: [-0.1, 0, 0], + 6783: [0.5, 0, 0], + 801: [0.5, 0, 0], + 5661: [0.5, 0, 0], + 675: [0, 0.5, 0], + 6526: [-0.5, 0, 0], + 7285: [-0.5, 0, 0], + 622: [0, -0.3, 0.3], + 4746: [0, -0.3, 0.3], + 1623: [0, -0.5, 0], + 5574: [0, 0.5, 0], + 1847: [0, 1.2, 0], + 2470: [0, 1.2, 0], + 2240: [-1, 0, 0], + 6694: [0, 0.2, 0], + 2180: [0.5, 0, 0], + 138: [0.5, -0.1, 0.1], + 175: [0.2, 0, 0], + 1899: [0.2, 0.2, 0], + 3858: [0, -2, 0.1], + 3952: [0.5, -0.1, 0.1], + 4156: [0.5, -0.1, 0.1], + 6077: [0.2, 0.2, 0], + 6875: [-0.2, -0.2, 0], + 7007: [-0.2, -0.2, 0], + 498: [0.5, 0, 0], + 3406: [0.5, 0, 0], + 3627: [-0.2, -0.5, 0], + 4239: [-0.3, 0, 0], + 412: [0, -0.1, 0], + 3347: [0, -0.1, 0], + 1944: [-0.2, -0.2, 0], + 2668: [-0.2, -0.2, 0], + 2749: [-0.5, 0, 0], + 1182: [0, -0.6, 0], +} + + +def revise_one_data(origin): + """ + Apply an offset amendment to the start position and first waypoint of the reference path + for a given trajectory, if it belongs to a known fall-amend trajectory group. + + The offset is selected based on: + - `fall_path_z_0_3` → fixed offset [0, 0, 0.3] + - `fall_path_custom` → custom offset mapped by trajectory_id + - otherwise → return original unchanged + + Args: + origin (dict): One navigation episode item containing keys such as + `trajectory_id`, `start_position`, and `reference_path`. + + Returns: + dict: The amended item with updated start position and first reference path waypoint, + or the original if no amendment rule matched. 
+ """ + trajectory_id = origin['trajectory_id'] + if trajectory_id in fall_path_z_0_3: + amend_offset = [0, 0, 0.3] + elif trajectory_id in fall_path_custom: + amend_offset = fall_path_custom[trajectory_id] + else: + return origin + origin['start_position'][0] = origin['start_position'][0] + amend_offset[0] + origin['start_position'][1] = origin['start_position'][1] + amend_offset[1] + origin['start_position'][2] = origin['start_position'][2] + amend_offset[2] + origin['reference_path'][0][0] = origin['reference_path'][0][0] + amend_offset[0] + origin['reference_path'][0][1] = origin['reference_path'][0][1] + amend_offset[1] + origin['reference_path'][0][2] = origin['reference_path'][0][2] + amend_offset[2] + return origin + + +def transform_rotation_z_90degrees(rotation): + """ + Rotate a quaternion by 90 degrees (π/2 radians) around the Z axis. + """ + z_rot_90 = [np.cos(np.pi / 4), 0, 0, np.sin(np.pi / 4)] # 90 degrees = pi/2 radians + w1, x1, y1, z1 = rotation + w2, x2, y2, z2 = z_rot_90 + revised_rotation = [ + w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2, # w + w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2, # x + w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2, # y + w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2, # z + ] + return revised_rotation + + +def has_stairs(item, height_threshold=0.3): + """ + Determine if a navigation reference path contains stair-like height jumps when the + instruction text includes the word 'stair'. + + The function checks the Z-height (3rd axis) differences between consecutive reference + waypoints and flags True if any jump exceeds the threshold. + + Args: + item (dict): Episode item containing `instruction.instruction_text` and `reference_path`. + height_threshold (float, optional): Minimum absolute height delta to consider as stairs. Defaults to 0.3. + + Returns: + bool: True if stairs are detected, False otherwise. + """ + has_stairs = False + if 'stair' in item['instruction']['instruction_text']: + latest_height = item['reference_path'][0][-1] + for index in range(1, len(item['reference_path'])): + position = item['reference_path'][index] + if abs(position[-1] - latest_height) >= height_threshold: + has_stairs = True + break + else: + latest_height = position[-1] + return has_stairs + + +def different_height(item): + """ + Check if multiple reference paths (or waypoints across paths) have significantly different + heights (Z-axis), indicating non-flat terrain. + + Args: + item (dict): Episode item containing a list of reference paths in `reference_path`. + + Returns: + bool: True if any adjacent path segment has a height difference > 0.3, else False. + """ + different_height = False + paths = item['reference_path'] + for path_idx in range(len(paths) - 1): + if abs(paths[path_idx + 1][2] - paths[path_idx][2]) > 0.3: + different_height = True + break + return different_height + + +def load_data( + dataset_root_dir, split, filter_same_trajectory=True, filter_stairs=True, dataset_type='mp3d', rank=0, world_size=1 +): + """ + Load a compressed navigation dataset split and organize episodes by scan/scene, + with optional filtering rules for duplicate trajectories and stair terrain. 
+ + Supported behaviors include: + - Distributed slicing via `rank::world_size` + - Scene grouping by `scan` (kujiale/grscene) or `scene_id` (mp3d) + - Coordinate system remapping for mp3d (`x, z, y` → `[x, -y, z]`) + - Start rotation quaternion remapping + 90° Z rotation + - Filtering repeated `trajectory_id` + - Filtering episodes containing stairs or uneven heights + + Args: + dataset_root_dir (str): Root data directory containing the split folders. + split (str): Dataset split name (folder & file prefix), e.g. "val_unseen". + filter_same_trajectory (bool, optional): Remove episodes with duplicate trajectory_id. Defaults to True. + filter_stairs (bool, optional): Remove episodes where stairs or large height variation are detected. Defaults to True. + dataset_type (str, optional): Dataset source identifier, such as "mp3d", "kujiale", or "grscene". Defaults to "mp3d". + rank (int, optional): Distributed process rank used for slicing episodes. Defaults to 0. + world_size (int, optional): Number of distributed ranks used for striding. Defaults to 1. + + Returns: + dict: Mapping from `scan` → List of filtered episode items for that scene. + """ + with gzip.open(os.path.join(dataset_root_dir, split, f"{split}.json.gz"), 'rt', encoding='utf-8') as f: + data = json.load(f)['episodes'][rank::world_size] + + if dataset_type in ['kujiale', 'grscene']: + scenes = list(set([x['scan'] for x in data])) + else: + scenes = list(set([x['scene_id'] for x in data])) # e.g. 'mp3d/zsNo4HB9uLZ/zsNo4HB9uLZ.glb' + + scenes.sort() + new_data = {} + for scene in scenes: + if dataset_type in ['kujiale', 'grscene']: + scene_data = [x for x in data if x['scan'] == scene] + scan = scene + else: + scene_data = [x for x in data if x['scene_id'] == scene] + scan = scene.split('/')[1] # e.g. 
'zsNo4HB9uLZ' + new_scene_data = [] + for item in scene_data: + new_item = copy.deepcopy(item) + new_item['scan'] = scan + new_item['original_start_position'] = item['start_position'] + new_item['original_start_rotation'] = item['start_rotation'] + if dataset_type == 'mp3d': + x, z, y = item['start_position'] + new_item['start_position'] = [x, -y, z] + r1, r2, r3, r4 = item['start_rotation'] + new_item['start_rotation'] = transform_rotation_z_90degrees([-r4, r1, r3, -r2]) + new_item['reference_path'] = [[x, -y, z] for x, z, y in item['reference_path']] + new_scene_data.append(new_item) + + new_data[scan] = new_scene_data + + data = copy.deepcopy(new_data) + new_data = defaultdict(list) + + # filter_same_trajectory + if filter_same_trajectory: + total_count = 0 + remaining_count = 0 + trajectory_list = [] + for scan, data_item in data.items(): + for item in data_item: + total_count += 1 + if item['trajectory_id'] in trajectory_list: + continue + remaining_count += 1 + trajectory_list.append(item['trajectory_id']) + new_data[scan].append(item) + log.info(f'[split:{split}]filter_same_trajectory remain: [ {remaining_count} / {total_count} ]') + data = new_data + new_data = defaultdict(list) + + if filter_stairs: + total_count = 0 + remaining_count = 0 + for scan, data_item in data.items(): + for item in data_item: + total_count += 1 + if has_stairs(item) or different_height(item): + continue + remaining_count += 1 + new_data[scan].append(item) + log.info(f'[split:{split}]filter_stairs remain: [ {remaining_count} / {total_count} ]') + data = new_data + + return data diff --git a/internnav/evaluator/utils/eval.py b/internnav/env/utils/episode_loader/generate_episode.py similarity index 52% rename from internnav/evaluator/utils/eval.py rename to internnav/env/utils/episode_loader/generate_episode.py index 9e6c447..b244045 100644 --- a/internnav/evaluator/utils/eval.py +++ b/internnav/env/utils/episode_loader/generate_episode.py @@ -1,11 +1,43 @@ -from internnav.configs.evaluator import EvalCfg -from internnav.evaluator.utils.common import load_kujiale_scene_usd, load_scene_usd -from internnav.projects.dataloader.resumable import ResumablePathKeyDataloader +import os +from internnav.configs.evaluator import TaskCfg +from internnav.utils.common_log_util import common_logger as log -def generate_episode(dataloader: ResumablePathKeyDataloader, config: EvalCfg): - scene_data_dir = config.task.scene.scene_data_dir - scene_asset_path = config.task.scene.scene_asset_path +from .resumable import ResumablePathKeyEpisodeloader + + +def load_scene_usd(mp3d_data_dir, scan): + """Load scene USD based on the scan""" + from internutopia.core.util import is_in_container + + find_flag = False + for root, dirs, files in os.walk(os.path.join(mp3d_data_dir, scan)): + target_file_name = 'fixed_docker.usd' if is_in_container() else 'fixed.usd' + for file in files: + if file == target_file_name: + scene_usd_path = os.path.join(root, file) + find_flag = True + break + if find_flag: + break + if not find_flag: + log.error('Scene USD not found for scan %s', scan) + return None + return scene_usd_path + + +def load_kujiale_scene_usd(kujiale_iros_data_dir, scan): + """Load scene USD based on the scan""" + scene_usd_path = os.path.join(kujiale_iros_data_dir, scan, f'{scan}.usda') + if not os.path.exists(scene_usd_path): + log.error('Scene USD not found for scan %s', scan) + return None + return scene_usd_path + + +def generate_vln_episode(dataloader: ResumablePathKeyEpisodeloader, task: TaskCfg): + scene_data_dir = 
task.scene.scene_data_dir - scene_asset_path = config.task.scene.scene_asset_path + scene_asset_path = task.scene.scene_asset_path eval_path_key_list = dataloader.resumed_path_key_list path_key_data = dataloader.path_key_data episodes = [] @@ -21,9 +53,9 @@ def generate_episode(dataloader: ResumablePathKeyDataloader, config: EvalCfg): from internnav.env.utils.internutopia_extension.configs.tasks import VLNEvalTaskCfg robot = H1RobotCfg( - **config.task.robot.robot_settings, - controllers=[ControllerCfg(**cfg.controller_settings) for cfg in config.task.robot.controllers], - sensors=[RepCameraCfg(**cfg.sensor_settings) for cfg in config.task.robot.sensors], + **task.robot.robot_settings, + controllers=[ControllerCfg(**cfg.controller_settings) for cfg in task.robot.controllers], + sensors=[RepCameraCfg(**cfg.sensor_settings) for cfg in task.robot.sensors], ) for path_key in eval_path_key_list: @@ -33,23 +65,23 @@ def generate_episode(dataloader: ResumablePathKeyDataloader, config: EvalCfg): data['path_key'] = path_key data['name'] = dataloader.task_name - if config.task.scene.scene_type == 'kujiale': + if task.scene.scene_type == 'kujiale': load_scene_func = load_kujiale_scene_usd scene_scale = (1, 1, 1) else: load_scene_func = load_scene_usd scene_scale = (1, 1, 1) - robot_flash = getattr(config.task, "robot_flash", False) - one_step_stand_still = getattr(config.task, "one_step_stand_still", False) - if config.task.metric.metric_setting['metric_config'].get('name', None) is None: - config.task.metric.metric_setting['metric_config']['name'] = 'default_eval_name' + robot_flash = getattr(task, "robot_flash", False) + one_step_stand_still = getattr(task, "one_step_stand_still", False) + if task.metric.metric_setting['metric_config'].get('name', None) is None: + task.metric.metric_setting['metric_config']['name'] = 'default_eval_name' episodes.append( VLNEvalTaskCfg( - **config.task.task_settings, + **task.task_settings, robot_flash=robot_flash, one_step_stand_still=one_step_stand_still, - metrics=[VLNPEMetricCfg(**config.task.metric.metric_setting['metric_config'])], + metrics=[VLNPEMetricCfg(**task.metric.metric_setting['metric_config'])], scene_asset_path=load_scene_func(scene_data_dir, dataloader.path_key_scan[path_key]) if scene_asset_path == '' else scene_asset_path, diff --git a/internnav/projects/dataloader/resumable.py b/internnav/env/utils/episode_loader/resumable.py similarity index 81% rename from internnav/projects/dataloader/resumable.py rename to internnav/env/utils/episode_loader/resumable.py index 73cdd6d..506a672 100644 --- a/internnav/projects/dataloader/resumable.py +++ b/internnav/env/utils/episode_loader/resumable.py @@ -1,13 +1,14 @@ +import os + import lmdb import msgpack_numpy from internnav.evaluator.utils.config import get_lmdb_path -from .base import BasePathKeyDataloader -from .data_reviser import skip_list +from .base import BasePathKeyEpisodeloader -class ResumablePathKeyDataloader(BasePathKeyDataloader): +class ResumablePathKeyEpisodeloader(BasePathKeyEpisodeloader): def __init__( self, dataset_type, @@ -19,6 +20,8 @@ def __init__( run_type, retry_list, filter_stairs, + rank=0, + world_size=1, ): # Load all data super().__init__( @@ -29,23 +32,26 @@ def __init__( filter_same_trajectory=filter_same_trajectory, revise_data=True, filter_stairs=filter_stairs, + rank=rank, + world_size=world_size, ) self.task_name = task_name self.run_type = run_type self.lmdb_path = get_lmdb_path(task_name) self.retry_list = retry_list + + if not os.path.exists(self.lmdb_path): + os.makedirs(self.lmdb_path, exist_ok=True) + database = lmdb.open( - 
f'{self.lmdb_path}/sample_data.lmdb', + f'{self.lmdb_path}/sample_data{rank}.lmdb', map_size=1 * 1024 * 1024 * 1024 * 1024, - readonly=True, - lock=False, + lock=True, ) filtered_target_path_key_list = [] for path_key in self.path_key_data.keys(): - trajectory_id = int(path_key.split('_')[0]) - if trajectory_id in skip_list: - continue + # trajectory_id = int(path_key.split('_')[0]) with database.begin() as txn: value = txn.get(path_key.encode()) if value is None: diff --git a/internnav/env/utils/habitat_extensions/evaluator_single.py b/internnav/env/utils/habitat_extensions/evaluator_single.py deleted file mode 100644 index 24d95ea..0000000 --- a/internnav/env/utils/habitat_extensions/evaluator_single.py +++ /dev/null @@ -1,695 +0,0 @@ -import argparse -import copy -import itertools -import os -import random -import re -import sys -from collections import OrderedDict -from typing import Any - -import cv2 -import habitat -import numpy as np -import quaternion -import torch -from depth_camera_filtering import filter_depth -from habitat import Env -from habitat.config.default import get_agent_config -from habitat.config.default_structured_configs import ( - CollisionsMeasurementConfig, - FogOfWarConfig, - TopDownMapMeasurementConfig, -) -from habitat.utils.visualizations.utils import images_to_video, observations_to_image -from habitat_baselines.config.default import get_config as get_habitat_config -from omegaconf import OmegaConf -from PIL import Image -from transformers import AutoProcessor -from transformers.image_utils import to_numpy_array - -PROJECT_ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -sys.path.append(PROJECT_ROOT_PATH) -print(f"PROJECT_ROOT_PATH {PROJECT_ROOT_PATH}") -from internnav.model.basemodel.internvla_n1.internvla_n1 import InternVLAN1ForCausalLM -from internnav.model.utils.vln_utils import ( - chunk_token, - split_and_clean, - traj_to_actions, -) -from internnav.utils.dist import get_rank, get_world_size, init_distributed_mode - -DEFAULT_IMAGE_TOKEN = "" - - -class VLNEvaluator: - def __init__( - self, - config_path: str, - split: str = "val_seen", - env_num: int = 1, - output_path: str = None, - model: Any = None, - processor: Any = None, - epoch: int = 0, - args: argparse.Namespace = None, - ): - self.args = args - self.device = torch.device('cuda') - self.split = split - self.env_num = env_num - self.save_video = args.save_video - self.output_path = output_path - self.epoch = epoch - self.config_path = config_path - self.config = get_habitat_config(config_path) - self.agent_config = get_agent_config(self.config.habitat.simulator) - self.sim_sensors_config = self.config.habitat.simulator.agents.main_agent.sim_sensors - - # for gradio evaluation - self.infer_data_ready = False - self.infer_scene_id = 0 - self.infer_episode_id = 0 - self.infer_success_cnt = -1 - self.infer_instruction = "" - self.infer_success = False - self.env = None - with habitat.config.read_write(self.config): - # self.config.habitat.task.measurements.success.success_distance=3.0 - self.config.habitat.dataset.split = self.split - self.config.habitat.task.measurements.update( - { - "top_down_map": TopDownMapMeasurementConfig( - map_padding=3, - map_resolution=1024, - draw_source=True, - draw_border=True, - draw_shortest_path=True, - draw_view_points=True, - draw_goal_positions=True, - draw_goal_aabbs=True, - fog_of_war=FogOfWarConfig( - draw=True, - visibility_dist=5.0, - fov=90, - ), - ), - "collisions": CollisionsMeasurementConfig(), - } - ) - - 
print(f"config = {type(self.config)}") - print(OmegaConf.to_yaml(self.config)) - - self._camera_height = self.sim_sensors_config.rgb_sensor.position[1] - self._min_depth = self.sim_sensors_config.depth_sensor.min_depth - self._max_depth = self.sim_sensors_config.depth_sensor.max_depth - - camera_fov_rad = np.deg2rad(self.sim_sensors_config.depth_sensor.hfov) - self._camera_fov = camera_fov_rad - self._fx = self._fy = self.sim_sensors_config.depth_sensor.width / (2 * np.tan(camera_fov_rad / 2)) - - self.model = model - self.processor = processor - - prompt = "You are an autonomous navigation assistant. Your task is to . Where should you go next to stay on track? Please output the next waypoint\'s coordinates in the image. Please output STOP when you have successfully completed the task." - answer = "" - self.conversation = [{"from": "human", "value": prompt}, {"from": "gpt", "value": answer}] - - self.conjunctions = [ - 'you can see ', - 'in front of you is ', - 'there is ', - 'you can spot ', - 'you are toward the ', - 'ahead of you is ', - 'in your sight is ', - ] - - self.actions2idx = OrderedDict( - { - 'STOP': [0], - "↑": [1], - "←": [2], - "→": [3], - "↓": [5], - } - ) - - self.num_frames = args.num_frames - self.num_future_steps = args.num_future_steps - self.num_history = args.num_history - - def preprocess_depth_image_v2( - self, depth_image, do_depth_scale=True, depth_scale=1000, target_height=None, target_width=None - ): - if target_height is None: - target_height = self.image_processor.crop_size['height'] # 384 - target_width = self.image_processor.crop_size['width'] # 384 - - resized_depth_image = depth_image.resize((target_width, target_height), Image.NEAREST) - - img = to_numpy_array(resized_depth_image) - if do_depth_scale: - img = img / depth_scale - - return img, (target_width, target_height) - - def get_intrinsic_matrix(self, sensor_cfg) -> np.ndarray: - width = sensor_cfg.width - height = sensor_cfg.height - fov = sensor_cfg.hfov - fx = (width / 2.0) / np.tan(np.deg2rad(fov / 2.0)) - fy = fx # Assuming square pixels (fx = fy) - cx = (width - 1.0) / 2.0 - cy = (height - 1.0) / 2.0 - - intrinsic_matrix = np.array( - [[fx, 0.0, cx, 0.0], [0.0, fy, cy, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]] - ) - return intrinsic_matrix - - def preprocess_instrinsic(self, intrinsic, ori_size, target_size): # (V, 4, 4) (resize_shape) (h, w) - intrinsic = copy.deepcopy(intrinsic) - if len(intrinsic.shape) == 2: - intrinsic = intrinsic[None, :, :] # (1, 4, 4) or (B, 4, 4) - - intrinsic[:, 0] /= ori_size[0] / target_size[0] # width - intrinsic[:, 1] /= ori_size[1] / target_size[1] # height - - # for crop transform - intrinsic[:, 0, 2] -= (target_size[0] - target_size[1]) / 2 - - if intrinsic.shape[0] == 1: - intrinsic = intrinsic.squeeze(0) - - return intrinsic - - def get_axis_align_matrix(self): - ma = np.array([[0, 0, 1, 0], [-1, 0, 0, 0], [0, -1, 0, 0], [0, 0, 0, 1]]) - return ma - - def xyz_yaw_to_tf_matrix(self, xyz: np.ndarray, yaw: float) -> np.ndarray: - x, y, z = xyz - transformation_matrix = np.array( - [ - [np.cos(yaw), -np.sin(yaw), 0, x], - [np.sin(yaw), np.cos(yaw), 0, y], - [0, 0, 1, z], - [0, 0, 0, 1], - ] - ) - return transformation_matrix - - def xyz_pitch_to_tf_matrix(self, xyz: np.ndarray, pitch: float) -> np.ndarray: - x, y, z = xyz - transformation_matrix = np.array( - [ - [np.cos(pitch), 0, np.sin(pitch), x], - [0, 1, 0, y], - [-np.sin(pitch), 0, np.cos(pitch), z], - [0, 0, 0, 1], - ] - ) - return transformation_matrix - - def xyz_yaw_pitch_to_tf_matrix(self, 
xyz: np.ndarray, yaw: float, pitch: float) -> np.ndarray: - """Converts a given position and yaw, pitch angles to a 4x4 transformation matrix. - - Args: - xyz (np.ndarray): A 3D vector representing the position. - yaw (float): The yaw angle in radians. - pitch (float): The pitch angle in radians for y axis. - Returns: - np.ndarray: A 4x4 transformation matrix. - """ - x, y, z = xyz - rot1 = self.xyz_yaw_to_tf_matrix(xyz, yaw)[:3, :3] - rot2 = self.xyz_pitch_to_tf_matrix(xyz, pitch)[:3, :3] - transformation_matrix = np.eye(4) - transformation_matrix[:3, :3] = rot1 @ rot2 - transformation_matrix[:3, 3] = xyz - return transformation_matrix - - def pixel_to_gps(self, pixel, depth, intrinsic, tf_camera_to_episodic): - ''' - Args: - pixel: (2,) - [u, v] pixel coordinates - depth: (H, W) - depth image where depth[v, u] gives depth in meters - intrinsic: (4, 4) - camera intrinsic matrix - tf_camera_to_episodic: (4, 4) - transformation from camera to episodic frame - Returns: - (x, y): (x, y) coordinates in the episodic frame - ''' - v, u = pixel - z = depth[v, u] - print("depthhhhhhhhhhhhhh", z) - - x = (u - intrinsic[0, 2]) * z / intrinsic[0, 0] - y = (v - intrinsic[1, 2]) * z / intrinsic[1, 1] - point_camera = np.array([x, y, z, 1.0]) - - # Transform to episodic frame - point_episodic = tf_camera_to_episodic @ point_camera - point_episodic = point_episodic[:3] / point_episodic[3] - - x = point_episodic[0] - y = point_episodic[1] - - return (x, y) # same as habitat gps - - def config_env(self) -> Env: - env = Env(config=self.config) - # env.episodes = env.episodes[0:1] - return env - - def run_single_eval(self): - - self.model.eval() - self.env = self.config_env() - self.scene_episode_dict = {} - for episode in self.env.episodes: - if episode.scene_id not in self.scene_episode_dict: - self.scene_episode_dict[episode.scene_id] = [] - self.scene_episode_dict[episode.scene_id].append(episode) - intrinsic_matrix = self.get_intrinsic_matrix( - self.config.habitat.simulator.agents.main_agent.sim_sensors.rgb_sensor - ) - sucs, spls, oss, nes = [], [], [], [] - - if True: # fixme - scenes_keys = list(sorted(self.scene_episode_dict.keys())) - - self.infer_success = False - self.infer_data_ready = False - print('---------------current infer scene:', scenes_keys[self.infer_scene_id]) - selected_scenes = ['17DRP5sb8fy', 'r1Q1Z4BcV1o', 'dhjEzFoUFzH'] - key_name = ( - 'data/scene_datasets/mp3d/' - + selected_scenes[self.infer_scene_id] - + '/' - + selected_scenes[self.infer_scene_id] - + '.glb' - ) - episodes = self.scene_episode_dict[key_name] - step_size = len(episodes) // 6 - # episode_id = 0 - - episode = episodes[self.infer_episode_id * step_size] - - episode_instruction = ( - self.infer_instruction - ) # episode.instruction.instruction_text if 'objectnav' not in self.config_path else episode.object_category - print("episode start", episode_instruction) - - # episode_id = int(episode.episode_id) - env = self.env - env.current_episode = episode - observations = env.reset() - - agent_state = env.sim.get_agent_state() - rotation = agent_state.rotation - translation = agent_state.position - rotation_matrix = quaternion.as_rotation_matrix(rotation) - transformation_matrix = np.eye(4) - transformation_matrix[:3, :3] = rotation_matrix - transformation_matrix[:3, 3] = translation - - # agent = ShortestPathFollower(env.sim, 0.25, False) - - os.makedirs(os.path.join(self.output_path, f'check_sim_{self.epoch}'), exist_ok=True) - - vis_frames = [] - step_id = 0 - - if self.save_video: - os.makedirs(self.output_path, 
exist_ok=True) - initial_height = env.sim.get_agent_state().position[1] - - rgb_list = [] - # depth_list = [] - action_seq = [] - # past_key_values = None - output_ids = None - - goal = None - action = None - # look_down_observations = None - # look_down_id_list = [] - messages = [] - # last_action = None - local_actions = [] - - # begin evaluation main loop - while not env.episode_over and step_id <= 70: - rgb = observations["rgb"] - depth = observations["depth"] - x, y = observations["gps"] - camera_yaw = observations["compass"][0] - depth = filter_depth(depth.reshape(depth.shape[:2]), blur_type=None) - depth = depth * (self._max_depth - self._min_depth) + self._min_depth - depth = depth * 1000 - - agent_state = env.sim.get_agent_state() - height = agent_state.position[1] - initial_height # Habitat GPS makes west negative, so flip y - camera_position = np.array([x, -y, self._camera_height + height]) - # robot_xy = camera_position[:2] - # tf_camera_to_episodic = self.xyz_yaw_to_tf_matrix(camera_position, camera_yaw) @ self.get_axis_align_matrix() - tf_camera_to_episodic = ( - self.xyz_yaw_pitch_to_tf_matrix(camera_position, camera_yaw, np.deg2rad(30)) - @ self.get_axis_align_matrix() - ) - - image = Image.fromarray(rgb).convert('RGB') # raw observation image - # image_size = image.size # 640*480 - save_raw_image = image.copy() - - if action == 5: - look_down_image = image # Image.fromarray(look_down_observations['rgb']).convert('RGB') - save_raw_image = look_down_image.copy() - - # rgb_list.append(look_down_image) - look_down_depth, resize_shape = self.preprocess_depth_image_v2( - Image.fromarray(depth.astype(np.uint16), mode='I;16'), - do_depth_scale=True, - depth_scale=1000, - target_height=224, - target_width=224, - ) - look_down_depth = torch.as_tensor(np.ascontiguousarray(look_down_depth)).float() # [H, W] - # depth clip to 5m - look_down_depth[look_down_depth > 5.0] = 5.0 - else: - image = image.resize((self.args.resize_w, self.args.resize_h)) - rgb_list.append(image) - - down_observations = env.step(5) - down_observations = env.step(5) - - look_down_image = Image.fromarray(down_observations["rgb"]).convert('RGB') - depth = down_observations["depth"] - depth = filter_depth(depth.reshape(depth.shape[:2]), blur_type=None) - depth = depth * (self._max_depth - self._min_depth) + self._min_depth - depth = depth * 1000 - look_down_depth, resize_shape = self.preprocess_depth_image_v2( - Image.fromarray(depth.astype(np.uint16), mode='I;16'), - do_depth_scale=True, - depth_scale=1000, - target_height=224, - target_width=224, - ) - look_down_depth = torch.as_tensor(np.ascontiguousarray(look_down_depth)).float() # [H, W] - # depth clip to 5m - look_down_depth[look_down_depth > 5.0] = 5.0 - - env.step(4) - env.step(4) - - info = env.get_metrics() - current_frame_infer_pixel = False - if len(action_seq) == 0 and goal is None: # 只有执行完一次输出的所有 action_seq 才能继续做模型推理 - if action != 5: - sources = copy.deepcopy(self.conversation) - if 'objectnav' in self.config_path: - sources[0]["value"] = sources[0]["value"].replace( - '.', - random.choice(self.objectnav_instructions).format( - target_object=episode.object_category.replace('_', ' ') - ), - ) - else: - sources[0]["value"] = sources[0]["value"].replace( - '.', episode_instruction[:-1] - ) - cur_images = rgb_list[-1:] # current observation - if step_id == 0: - history_id = [] - else: - history_id = np.unique( - np.linspace(0, step_id - 1, self.num_history, dtype=np.int32) - ).tolist() - placeholder = (DEFAULT_IMAGE_TOKEN + '\n') * len(history_id) - 
sources[0]["value"] += f' These are your historical observations: {placeholder}.' - - history_id = sorted(history_id) - print('history_idddddddd', step_id, history_id) - input_images = [rgb_list[i] for i in history_id] + cur_images - input_img_id = 0 - else: - assert action == 5 # last action is look down - sources = [{"from": "human", "value": ""}, {"from": "gpt", "value": ""}] - input_images += [look_down_image] - # messages.append({'role': 'assistant', 'content': [{'type': 'text', 'text': llm_outputs}]}) - input_img_id = -1 - - prompt = random.choice(self.conjunctions) + DEFAULT_IMAGE_TOKEN - sources[0]["value"] += f" {prompt}." - print('sources', step_id, sources) - prompt_instruction = copy.deepcopy(sources[0]["value"]) - parts = split_and_clean(prompt_instruction) - - content = [] - for i in range(len(parts)): - if parts[i] == "": - content.append({"type": "image", "image": input_images[input_img_id]}) - input_img_id += 1 - else: - content.append({"type": "text", "text": parts[i]}) - - messages.append({'role': 'user', 'content': content}) - - print('step_id', step_id, 'messages:', messages) - - text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - - inputs = self.processor(text=[text], images=input_images, return_tensors="pt").to(self.model.device) - - with torch.no_grad(): - output_ids = self.model.generate(**inputs, max_new_tokens=128, do_sample=False) - - llm_outputs = self.processor.tokenizer.decode( - output_ids[0][inputs.input_ids.shape[1] :], skip_special_tokens=True - ) - print('step_id:', step_id, 'output text:', llm_outputs) - - if bool(re.search(r'\d', llm_outputs)): # output pixel goal - current_frame_infer_pixel = True - forward_action = 0 - coord = [int(c) for c in re.findall(r'\d+', llm_outputs)] - pixel_goal = [int(coord[1]), int(coord[0])] # switch the goal o - - goal = self.pixel_to_gps(pixel_goal, depth / 1000, intrinsic_matrix, tf_camera_to_episodic) - print('before', goal, depth.shape) - goal = (transformation_matrix @ np.array([-goal[1], 0, -goal[0], 1]))[:3] - - if not env.sim.pathfinder.is_navigable(np.array(goal)): - goal = np.array(env.sim.pathfinder.snap_point(np.array(goal))) - - # look down --> horizontal - env.step(4) - env.step(4) - - # action = agent.get_next_action(goal) - local_actions = [] - pixel_values = inputs.pixel_values - image_grid_thw = torch.cat([thw.unsqueeze(0) for thw in inputs.image_grid_thw], dim=0) - - with torch.no_grad(): - traj_latents = self.model.generate_latents(output_ids, pixel_values, image_grid_thw) - - # image_dp = torch.tensor(np.array(look_down_image.resize((224, 224)))).to(torch.bfloat16) - # pix_goal_image = copy.copy(image_dp) - # images_dp = torch.stack([pix_goal_image, image_dp]).unsqueeze(0) - # depth_dp = look_down_depth.unsqueeze(-1).to(torch.bfloat16) - # pix_goal_depth = copy.copy(depth_dp) - # depths_dp = torch.stack([pix_goal_depth, depth_dp]).unsqueeze(0) - - with torch.no_grad(): - dp_actions = self.model.generate_traj(traj_latents) - - random_choice = np.random.choice(dp_actions.shape[0]) - if self.args.continuous_traj: - action_list = traj_to_actions(dp_actions) - if len(action_list) < 8: - action_list += [0] * (8 - len(action_list)) - else: - action_list = chunk_token(dp_actions[random_choice]) - print("first action_list", action_list) - local_actions = action_list - action = local_actions[0] - if action == 0: - goal = None - output_ids = None - action = 2 - print('conduct a random action 2') - observations = env.step(action) - step_id += 1 - messages = [] - 
continue - - print('predicted goal', pixel_goal, goal, flush=True) - else: - action_seq = self.parse_actions(llm_outputs) - print('actions', action_seq, flush=True) - - if len(action_seq) != 0: - action = action_seq[0] - action_seq.pop(0) - elif goal is not None: - if len(local_actions) != 0: - action = local_actions.pop(0) - else: - action = 0 - forward_action += 1 - print('forward_action', forward_action, flush=True) - if forward_action > 8: - goal = None - output_ids = None - messages = [] - step_id += 1 - forward_action = 0 - local_actions = [] - continue - if action == 0: - goal = None - output_ids = None - messages = [] - step_id += 1 - forward_action = 0 - local_actions = [] - continue - else: - action = 0 - - if info['top_down_map'] is not None: - frame = observations_to_image({'rgb': np.asarray(save_raw_image)}, info) - if current_frame_infer_pixel: - frame = cv2.putText( - frame, - f"{pixel_goal[1], pixel_goal[0]}", - (50, 80), - cv2.FONT_HERSHEY_SIMPLEX, - 1, - (255, 0, 0), - 2, - ) - frame = cv2.circle(frame, (pixel_goal[1], pixel_goal[0]), 5, (255, 0, 0), -1) - else: - output_str = str(action) - output_str = ( - output_str.replace('1', "Go forward").replace('2', 'Turn left').replace('3', 'Turn right') - ) - output_str = output_str.replace('5', 'Look down').replace('0', 'Stop!') - frame = cv2.putText(frame, output_str, (50, 80), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2) - vis_frames.append(frame) - if action == 5: - vis_frames.append(frame) - - print("step_id", step_id, "action", action) - - if action == 5: - env.step(action) - observations = env.step(action) - else: - observations = env.step(action) - step_id += 1 - messages = [] - - self.infer_success_cnt += 1 - - # metrics = env.get_metrics() - if self.save_video: - images_to_video(vis_frames, self.output_path, f"res_{self.infer_success_cnt}", fps=6, quality=9) - self.infer_success = True - vis_frames.clear() - - env.close() - return ( - torch.tensor(sucs).to(self.device), - torch.tensor(spls).to(self.device), - torch.tensor(oss).to(self.device), - torch.tensor(nes).to(self.device), - torch.tensor(len(sucs)).to(self.device), - ) - - def parse_actions(self, output): - action_patterns = '|'.join(re.escape(action) for action in self.actions2idx) - # import ipdb; ipdb.set_trace() - regex = re.compile(action_patterns) - matches = regex.findall(output) - actions = [self.actions2idx[match] for match in matches] - actions = itertools.chain.from_iterable(actions) - return list(actions) - - def preprocess_qwenvl(self, source): - prompt = random.choice(self.conjunctions) + DEFAULT_IMAGE_TOKEN - if len(source[0]["value"]) != 0: - source[0]["value"] += f" {prompt}." - else: - source[0]["value"] = f"{prompt}." # Please output the next waypoint\'s coordinates in the image." 
- return source - - -def eval(): - global local_rank - - parser = argparse.ArgumentParser() - parser.add_argument("--local_rank", default=0, type=int, help="node rank") - parser.add_argument("--model_path", type=str, default="") - parser.add_argument("--habitat_config_path", type=str, default='scripts/eval/configs/vln_r2r.yaml') - parser.add_argument("--eval_split", type=str, default='val_unseen') - parser.add_argument("--output_path", type=str, default='./exps_pix/val_unseen/debug_coord_wm') - parser.add_argument("--num_future_steps", type=int, default=4) - parser.add_argument("--num_frames", type=int, default=32) - parser.add_argument("--save_video", action="store_true", default=True) - parser.add_argument("--num_history", type=int, default=8) - parser.add_argument("--resize_w", type=int, default=384) - parser.add_argument("--resize_h", type=int, default=384) - parser.add_argument("--predict_step_nums", type=int, default=16) - parser.add_argument("--continuous_traj", action="store_true", default=False) - parser.add_argument("--max_new_tokens", type=int, default=1024) - - parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') - parser.add_argument('--rank', default=0, type=int, help='rank') - parser.add_argument('--gpu', default=0, type=int, help='gpu') - parser.add_argument('--port', default='2333') - parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') - parser.add_argument('--device', default='cuda', help='device to use for training / testing') - - args = parser.parse_args() - init_distributed_mode(args) - local_rank = args.local_rank - np.random.seed(local_rank) - # Load model and tokenizer - processor = AutoProcessor.from_pretrained(args.model_path) - processor.tokenizer.padding_side = 'left' - - device = torch.device(f"cuda:{local_rank}") - model = InternVLAN1ForCausalLM.from_pretrained( - args.model_path, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", device_map={"": device} - ) - - # start the evaluation - evaluate(model, processor, args) - - -def evaluate(model, processor, args): - model.eval() - world_size = get_world_size() - evaluator = VLNEvaluator( - config_path=args.habitat_config_path, - split=args.eval_split, - env_num=world_size, - output_path=args.output_path, - model=model, - processor=processor, - epoch=0, - args=args, - ) - sucs, spls, oss, nes, ep_num = evaluator.eval_action(idx=get_rank()) - - -if __name__ == "__main__": - eval() - print('hello world!!') - sys.exit(0) diff --git a/internnav/env/utils/habitat_extensions/fonts/arial.ttf b/internnav/env/utils/habitat_extensions/fonts/arial.ttf deleted file mode 100644 index 886789b..0000000 Binary files a/internnav/env/utils/habitat_extensions/fonts/arial.ttf and /dev/null differ diff --git a/internnav/env/utils/habitat_extensions/maps.py b/internnav/env/utils/habitat_extensions/maps.py deleted file mode 100644 index 68f62d7..0000000 --- a/internnav/env/utils/habitat_extensions/maps.py +++ /dev/null @@ -1,370 +0,0 @@ -from typing import Dict, List, Optional, Tuple, Union - -import networkx as nx -import numpy as np -from habitat.core.simulator import Simulator -from habitat.core.utils import try_cv2_import -from habitat.tasks.vln.vln import VLNEpisode -from habitat.utils.visualizations import maps as habitat_maps - -cv2 = try_cv2_import() - -AGENT_SPRITE = habitat_maps.AGENT_SPRITE - -MAP_THICKNESS_SCALAR: int = 128 - -MAP_INVALID_POINT = 0 -MAP_VALID_POINT = 1 -MAP_BORDER_INDICATOR = 2 
-MAP_SOURCE_POINT_INDICATOR = 4 -MAP_TARGET_POINT_INDICATOR = 6 -MAP_MP3D_WAYPOINT = 7 -MAP_VIEW_POINT_INDICATOR = 8 -MAP_TARGET_BOUNDING_BOX = 9 -MAP_REFERENCE_POINT = 10 -MAP_MP3D_REFERENCE_PATH = 11 -MAP_WAYPOINT_PREDICTION = 12 -MAP_ORACLE_WAYPOINT = 13 -MAP_SHORTEST_PATH_WAYPOINT = 14 - -TOP_DOWN_MAP_COLORS = np.full((256, 3), 150, dtype=np.uint8) -TOP_DOWN_MAP_COLORS[15:] = cv2.applyColorMap(np.arange(241, dtype=np.uint8), cv2.COLORMAP_JET).squeeze(1)[:, ::-1] -TOP_DOWN_MAP_COLORS[MAP_INVALID_POINT] = [255, 255, 255] # White -TOP_DOWN_MAP_COLORS[MAP_VALID_POINT] = [150, 150, 150] # Light Grey -TOP_DOWN_MAP_COLORS[MAP_BORDER_INDICATOR] = [50, 50, 50] # Grey -TOP_DOWN_MAP_COLORS[MAP_SOURCE_POINT_INDICATOR] = [0, 0, 200] # Blue -TOP_DOWN_MAP_COLORS[MAP_TARGET_POINT_INDICATOR] = [200, 0, 0] # Red -TOP_DOWN_MAP_COLORS[MAP_MP3D_WAYPOINT] = [0, 200, 0] # Green -TOP_DOWN_MAP_COLORS[MAP_VIEW_POINT_INDICATOR] = [245, 150, 150] # Light Red -TOP_DOWN_MAP_COLORS[MAP_TARGET_BOUNDING_BOX] = [0, 175, 0] # Dark Green -TOP_DOWN_MAP_COLORS[MAP_REFERENCE_POINT] = [0, 0, 0] # Black -TOP_DOWN_MAP_COLORS[MAP_MP3D_REFERENCE_PATH] = [0, 0, 0] # Black -TOP_DOWN_MAP_COLORS[MAP_WAYPOINT_PREDICTION] = [255, 255, 0] # Yellow -TOP_DOWN_MAP_COLORS[MAP_ORACLE_WAYPOINT] = [255, 165, 0] # Orange -TOP_DOWN_MAP_COLORS[MAP_SHORTEST_PATH_WAYPOINT] = [0, 150, 0] # Dark Green - - -def get_top_down_map(sim, map_resolution, meters_per_pixel): - base_height = sim.get_agent(0).state.position[1] - td_map = habitat_maps.get_topdown_map( - sim.pathfinder, - base_height, - map_resolution, - False, - meters_per_pixel, - ) - return td_map - - -def colorize_top_down_map( - top_down_map: np.ndarray, - fog_of_war_mask: Optional[np.ndarray] = None, - fog_of_war_desat_amount: float = 0.5, -) -> np.ndarray: - """Same as `maps.colorize_topdown_map` in Habitat-Lab, but with different - colors. - """ - _map = TOP_DOWN_MAP_COLORS[top_down_map] - - if fog_of_war_mask is not None: - fog_of_war_desat_values = np.array([[fog_of_war_desat_amount], [1.0]]) - # Only desaturate valid points as only valid points get revealed - desat_mask = top_down_map != MAP_INVALID_POINT - - _map[desat_mask] = (_map * fog_of_war_desat_values[fog_of_war_mask]).astype(np.uint8)[desat_mask] - - return _map - - -def static_to_grid( - realworld_x: float, - realworld_y: float, - grid_resolution: Tuple[int, int], - bounds: Dict[str, Tuple[float, float]], -) -> Tuple[int, int]: - """Return gridworld index of realworld coordinates assuming top-left - corner is the origin. The real world coordinates of lower left corner are - (coordinate_min, coordinate_min) and of top right corner are - (coordinate_max, coordinate_max). Same as the habitat-Lab maps.to_grid - function but with a static `bounds` instead of requiring a simulator or - pathfinder instance. 
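# Minimal sketch of the palette lookup that colorize_top_down_map relies on:
# a (256, 3) uint8 table indexed by the integer label map yields an RGB image in
# a single fancy-indexing step (toy palette and label map below).
import numpy as np

palette = np.full((256, 3), 150, dtype=np.uint8)
palette[0] = [255, 255, 255]  # MAP_INVALID_POINT -> white
palette[1] = [150, 150, 150]  # MAP_VALID_POINT -> light grey
palette[2] = [50, 50, 50]     # MAP_BORDER_INDICATOR -> grey

label_map = np.array([[0, 1], [1, 2]], dtype=np.uint8)
rgb = palette[label_map]      # shape (2, 2, 3)
print(rgb[1, 1])              # [50 50 50]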
- """ - grid_size = ( - abs(bounds["upper"][2] - bounds["lower"][2]) / grid_resolution[0], - abs(bounds["upper"][0] - bounds["lower"][0]) / grid_resolution[1], - ) - grid_x = int((realworld_x - bounds["lower"][2]) / grid_size[0]) - grid_y = int((realworld_y - bounds["lower"][0]) / grid_size[1]) - return grid_x, grid_y - - -def drawline( - img: np.ndarray, - pt1: Union[Tuple[float], List[float]], - pt2: Union[Tuple[float], List[float]], - color: List[int], - thickness: int = 1, - style: str = "dotted", - gap: int = 15, -) -> None: - """https://stackoverflow.com/questions/26690932/opencv-rectangle-with-dotted-or-dashed-lines - style: "dotted", "dashed", or "filled" - """ - assert style in ["dotted", "dashed", "filled"] - - if style == "filled": - cv2.line(img, pt1, pt2, color, thickness) - return - - dist = ((pt1[0] - pt2[0]) ** 2 + (pt1[1] - pt2[1]) ** 2) ** 0.5 - pts = [] - for i in np.arange(0, dist, gap): - r = i / dist - x = int((pt1[0] * (1 - r) + pt2[0] * r) + 0.5) - y = int((pt1[1] * (1 - r) + pt2[1] * r) + 0.5) - pts.append((x, y)) - - if style == "dotted": - for p in pts: - cv2.circle(img, p, thickness, color, -1) - else: - s = pts[0] - e = pts[0] - for i, p in enumerate(pts): - s = e - e = p - if i % 2 == 1: - cv2.line(img, s, e, color, thickness) - - -def drawpoint( - img: np.ndarray, - position: Union[Tuple[int], List[int]], - color: List[int], - meters_per_px: float, - pad: float = 0.3, -) -> None: - point_padding = int(pad / meters_per_px) - img[ - position[0] - point_padding : position[0] + point_padding + 1, - position[1] - point_padding : position[1] + point_padding + 1, - ] = color - - -def draw_triangle( - img: np.ndarray, - centroid: Union[Tuple[int], List[int]], - color: List[int], - meters_per_px: float, - pad: int = 0.35, -) -> None: - point_padding = int(pad / meters_per_px) - - # (Y, X) - left = (centroid[1] - point_padding, centroid[0] + point_padding) - right = (centroid[1] + point_padding, centroid[0] + point_padding) - top = (centroid[1], centroid[0] - point_padding) - cv2.drawContours(img, [np.array([left, right, top])], 0, color, -1) - - -def draw_reference_path( - img: np.ndarray, - sim: Simulator, - episode: VLNEpisode, - map_resolution: int, - meters_per_px: float, -) -> None: - """Draws lines between each waypoint in the reference path.""" - shortest_path_points = [ - habitat_maps.to_grid( - p[2], - p[0], - img.shape[0:2], - sim, - )[::-1] - for p in episode.reference_path - ] - - pt_from = None - for i, pt_to in enumerate(shortest_path_points): - - if i != 0: - drawline( - img, - (pt_from[0], pt_from[1]), - (pt_to[0], pt_to[1]), - MAP_SHORTEST_PATH_WAYPOINT, - thickness=int(0.4 * map_resolution / MAP_THICKNESS_SCALAR), - style="dashed", - gap=10, - ) - pt_from = pt_to - - for pt in shortest_path_points: - drawpoint(img, (pt[1], pt[0]), MAP_SHORTEST_PATH_WAYPOINT, meters_per_px) - - -def draw_straight_shortest_path_points( - img: np.ndarray, - sim: Simulator, - map_resolution: int, - shortest_path_points: List[List[float]], -) -> None: - """Draws the shortest path from start to goal assuming a standard - discrete action space. 
- """ - shortest_path_points = [habitat_maps.to_grid(p[2], p[0], img.shape[0:2], sim)[::-1] for p in shortest_path_points] - - habitat_maps.draw_path( - img, - [(p[1], p[0]) for p in shortest_path_points], - MAP_SHORTEST_PATH_WAYPOINT, - int(0.4 * map_resolution / MAP_THICKNESS_SCALAR), - ) - - -def draw_source_and_target(img: np.ndarray, sim: Simulator, episode: VLNEpisode, meters_per_px: float) -> None: - s_x, s_y = habitat_maps.to_grid( - episode.start_position[2], - episode.start_position[0], - img.shape[0:2], - sim, - ) - drawpoint(img, (s_x, s_y), MAP_SOURCE_POINT_INDICATOR, meters_per_px) - - # mark target point - t_x, t_y = habitat_maps.to_grid( - episode.goals[0].position[2], - episode.goals[0].position[0], - img.shape[0:2], - sim, - ) - drawpoint(img, (t_x, t_y), MAP_TARGET_POINT_INDICATOR, meters_per_px) - - -def draw_waypoint_prediction( - img: np.ndarray, - waypoint: Union[Tuple[float], List[float]], - meters_per_px: float, - bounds: Dict[str, Tuple[float]], -) -> None: - w_x, w_y = static_to_grid(waypoint[1], waypoint[0], img.shape[0:2], bounds) - if w_x < img.shape[0] and w_x > 0 and w_y < img.shape[1] and w_y > 0: - draw_triangle(img, (w_x, w_y), MAP_WAYPOINT_PREDICTION, meters_per_px) - - -def draw_oracle_waypoint( - img: np.ndarray, - waypoint: Union[Tuple[float], List[float]], - meters_per_px: float, - bounds: Dict[str, Tuple[float]], -) -> None: - w_x, w_y = static_to_grid(waypoint[1], waypoint[0], img.shape[0:2], bounds) - draw_triangle(img, (w_x, w_y), MAP_ORACLE_WAYPOINT, meters_per_px, pad=0.2) - - -def get_nearest_node(graph: nx.Graph, current_position: List[float]) -> str: - """Determine the closest MP3D node to the agent's start position as given - by a [x,z] position vector. - Returns: - node ID - """ - nearest = None - dist = float("inf") - for node in graph: - node_pos = graph.nodes[node]["position"] - node_pos = np.take(node_pos, (0, 2)) - cur_dist = np.linalg.norm(np.array(node_pos) - np.array(current_position), ord=2) - if cur_dist < dist: - dist = cur_dist - nearest = node - return nearest - - -def update_nearest_node(graph: nx.Graph, nearest_node: str, current_position: np.ndarray) -> str: - """Determine the closest MP3D node to the agent's current position as - given by a [x,z] position vector. The selected node must be reachable - from the previous MP3D node as specified in the nav-graph edges. - Returns: - node ID - """ - nearest = None - dist = float("inf") - - for node in [nearest_node] + [e[1] for e in graph.edges(nearest_node)]: - node_pos = graph.nodes[node]["position"] - node_pos = np.take(node_pos, (0, 2)) - cur_dist = np.linalg.norm(np.array(node_pos) - np.array(current_position), ord=2) - if cur_dist < dist: - dist = cur_dist - nearest = node - return nearest - - -def draw_mp3d_nodes( - img: np.ndarray, - sim: Simulator, - episode: VLNEpisode, - graph: nx.Graph, - meters_per_px: float, -) -> None: - n = get_nearest_node(graph, (episode.start_position[0], episode.start_position[2])) - starting_height = graph.nodes[n]["position"][1] - for node in graph: - pos = graph.nodes[node]["position"] - - # no obvious way to differentiate between floors. Use this for now. 
- if abs(pos[1] - starting_height) < 1.0: - r_x, r_y = habitat_maps.to_grid(pos[2], pos[0], img.shape[0:2], sim) - - # only paint if over a valid point - if img[r_x, r_y]: - drawpoint(img, (r_x, r_y), MAP_MP3D_WAYPOINT, meters_per_px) - - -from typing import Tuple - -import torch -from torch import Tensor - - -def image_resize( - img: Tensor, - size: Tuple[int, int], - channels_last: bool = False, - interpolation_mode: str = "area", -) -> torch.Tensor: - """Resizes an img. - - Args: - img: the array object that needs to be resized (HWC) or (NHWC) - size: the size that you want - channels: a boolean that channel is the last dimension - Returns: - The resized array as a torch tensor. - """ - img = torch.as_tensor(img) - no_batch_dim = len(img.shape) == 3 - if len(img.shape) < 3 or len(img.shape) > 5: - raise NotImplementedError() - if no_batch_dim: - img = img.unsqueeze(0) # Adds a batch dimension - if channels_last: - if len(img.shape) == 4: - # NHWC -> NCHW - img = img.permute(0, 3, 1, 2) - else: - # NDHWC -> NDCHW - img = img.permute(0, 1, 4, 2, 3) - - img = torch.nn.functional.interpolate(img.float(), size=size, mode=interpolation_mode).to(dtype=img.dtype) - if channels_last: - if len(img.shape) == 4: - # NCHW -> NHWC - img = img.permute(0, 2, 3, 1) - else: - # NDCHW -> NDHWC - img = img.permute(0, 1, 3, 4, 2) - if no_batch_dim: - img = img.squeeze(dim=0) # Removes the batch dimension - return img diff --git a/internnav/evaluator/__init__.py b/internnav/evaluator/__init__.py index 88393e5..35968e8 100644 --- a/internnav/evaluator/__init__.py +++ b/internnav/evaluator/__init__.py @@ -1,4 +1,12 @@ from internnav.evaluator.base import Evaluator -from internnav.evaluator.vln_multi_evaluator import VlnMultiEvaluator +from internnav.evaluator.distributed_base import DistributedEvaluator +from internnav.evaluator.vln_distributed_evaluator import VLNDistributedEvaluator -__all__ = ['Evaluator', 'VlnMultiEvaluator'] +# register habitat +try: + import internnav.habitat_extensions # noqa: F401 # isort: skip +except Exception as e: + print(f"Warning: ({e}), Habitat Evaluation is not loaded in this runtime. Ignore this if not using Habitat.") + + +__all__ = ['Evaluator', 'DistributedEvaluator', 'VLNDistributedEvaluator', 'HabitatVLNEvaluator'] diff --git a/internnav/evaluator/distributed_base.py b/internnav/evaluator/distributed_base.py new file mode 100644 index 0000000..9df4e74 --- /dev/null +++ b/internnav/evaluator/distributed_base.py @@ -0,0 +1,186 @@ +import json +import os + +import numpy as np +import torch + +from internnav.configs.evaluator import EvalCfg +from internnav.env import Env +from internnav.evaluator import Evaluator +from internnav.utils.dist import ( + dist, + get_rank, + get_world_size, + init_distributed_mode, + is_dist_avail_and_initialized, +) + + +class DistributedEvaluator(Evaluator): + """ + Base class of distributed evaluators. 
+ + Args: + eval_cfg (EvalCfg): Evaluation configuration + init_env (bool): Whether to initialize the environment + init_agent (bool): Whether to initialize the agent + """ + + def __init__(self, eval_cfg: EvalCfg, init_env: bool = True, init_agent: bool = True): + # distributed setting + if not eval_cfg.eval_settings.get('use_agent_server', False): + self.local_rank = init_distributed_mode( + dist_url=eval_cfg.eval_settings.get('dist_url', "env://"), + port=eval_cfg.eval_settings.get('port', 29529), + ) + else: + self.local_rank = 0 + np.random.seed(self.local_rank) + + self.rank = get_rank() + self.world_size = get_world_size() + self.output_path = eval_cfg.eval_settings.get("output_path") + + # habitat env also need rank to split dataset + eval_cfg.env.env_settings['rank'] = get_rank() + eval_cfg.env.env_settings['local_rank'] = self.local_rank + eval_cfg.env.env_settings['world_size'] = get_world_size() + + self.eval_config = eval_cfg + + if init_env: + self.env = Env.init(eval_cfg.env, eval_cfg.task) + + # -------- initialize agent config (either remote server or local agent) -------- + if init_agent: + if eval_cfg.eval_settings.get('use_agent_server', False): + assert not is_dist_avail_and_initialized(), "agent server requires single evaluator process." + # set agent port based on rank + from internnav.utils import AgentClient + + print(f"[R{self.rank}] Connecting to agent server at port {eval_cfg.agent.server_port}") + self.agent = AgentClient(eval_cfg.agent) + else: + from internnav.agent import Agent + + eval_cfg.agent.model_settings['local_rank'] = self.local_rank + self.agent = Agent.init(eval_cfg.agent) + + def eval(self): + """ + Uniform distributed evaluation pipeline: + + 1. Call subclass's eval_action() to get local per-episode tensors. + 2. Use dist all_gather (+ padding) to build global tensors for each metric. + 3. Call subclass's calc_metrics(global_metrics) to compute scalar metrics. + 4. Print + rank 0 writes result.json. 
+ """ + local_metrics = self.eval_action() # dict[str, Tensor], each [N_local] + + if not local_metrics: + raise RuntimeError("eval_action() returned empty metrics dict.") + + first_tensor = next(iter(local_metrics.values())) + device = first_tensor.device + local_len = first_tensor.shape[0] + + world_size = get_world_size() + + # -------- 1) Handle non-distributed / world_size == 1 -------- + if world_size == 1: + global_metrics = {name: tensor.detach().cpu() for name, tensor in local_metrics.items()} + total_len = int(local_len) + else: + # -------- 2) Gather lengths from all ranks -------- + local_len_t = torch.tensor([local_len], dtype=torch.long, device=device) + len_list = [torch.zeros_like(local_len_t) for _ in range(world_size)] + dist.all_gather(len_list, local_len_t) + lens = torch.stack(len_list).cpu() # shape [world_size, 1] + lens = lens.view(-1) # [world_size] + max_len = int(lens.max().item()) + total_len = int(lens.sum().item()) + + # -------- 3) For each metric, pad + all_gather + unpad -------- + global_metrics = {} + for name, tensor in local_metrics.items(): + assert tensor.shape[0] == local_len, ( + f"Metric {name} length ({tensor.shape[0]}) " f"!= first metric length ({local_len})" + ) + + # pad to max_len on this rank + padded = torch.zeros( + max_len, + dtype=tensor.dtype, + device=device, + ) + padded[:local_len] = tensor + + # gather padded tensors from all ranks + gathered = [torch.zeros_like(padded) for _ in range(world_size)] + dist.all_gather(gathered, padded) + + # unpad & concat using true lengths + parts = [] + for rank in range(world_size): + cur_len = int(lens[rank].item()) + if cur_len > 0: + parts.append(gathered[rank][:cur_len]) + if parts: + global_tensor = torch.cat(parts, dim=0) + else: + # no episodes at all (edge case) + global_tensor = torch.empty(0, dtype=tensor.dtype) + + global_metrics[name] = global_tensor.detach().cpu() + + # -------- 4) Let subclass compute final metrics from global tensors -------- + result_all = self.calc_metrics(global_metrics) + result_all.setdefault("length", total_len) + + # -------- 5) Logging -------- + print(result_all) + if get_rank() == 0: + os.makedirs(self.output_path, exist_ok=True) + out_path = os.path.join(self.output_path, "result.json") + with open(out_path, "a") as f: + f.write(json.dumps(result_all) + "\n") + + return result_all + + # ================= ABSTRACT HOOKS ================= + + def eval_action(self) -> dict: + """ + Run evaluation on this rank and return per-episode metrics. + + Returns: + dict[str, torch.Tensor] + Example: + { + "sucs": tensor([0., 1., ...], device=...), + "spls": tensor([...]), + "oss": tensor([...]), + "nes": tensor([...]), + } + """ + raise NotImplementedError + + def calc_metrics(self, global_metrics: dict) -> dict: + """ + Compute final scalar metrics from global per-episode tensors. + + Args: + global_metrics : dict[str, torch.Tensor] + For each metric name, a 1-D CPU tensor with all episodes across all ranks. + Example: + { + "sucs": tensor([...], dtype=torch.float32), + "spls": tensor([...]), + ... + } + + Returns: + dict[str, float] + Final scalar metrics to log. 
+ """ + raise NotImplementedError diff --git a/internnav/evaluator/habitat_vln_evaluator.py b/internnav/evaluator/habitat_vln_evaluator.py deleted file mode 100644 index 3bf4c54..0000000 --- a/internnav/evaluator/habitat_vln_evaluator.py +++ /dev/null @@ -1,811 +0,0 @@ -import argparse -import copy -import itertools -import json -import os -import random -import re -from collections import OrderedDict -from typing import Any - -import habitat -import numpy as np -import quaternion -import torch -import tqdm -from depth_camera_filtering import filter_depth -from habitat import Env -from habitat.config.default import get_agent_config -from habitat.config.default_structured_configs import ( - CollisionsMeasurementConfig, - FogOfWarConfig, - TopDownMapMeasurementConfig, -) -from habitat.tasks.nav.shortest_path_follower import ShortestPathFollower -from habitat.utils.visualizations.utils import images_to_video, observations_to_image -from habitat_baselines.config.default import get_config as get_habitat_config -from omegaconf import OmegaConf -from PIL import Image, ImageDraw, ImageFont -from torch import Tensor -from transformers.image_utils import to_numpy_array - -from internnav.model.utils.vln_utils import ( - chunk_token, - image_resize, - open_image, - rho_theta, - split_and_clean, - traj_to_actions, -) -from internnav.utils.dist import * # noqa: F403 - -DEFAULT_IMAGE_TOKEN = "" - - -class VLNEvaluator: - def __init__( - self, - config_path: str, - split: str = "val_seen", - env_num: int = 1, - output_path: str = None, - model: Any = None, - processor: Any = None, - epoch: int = 0, - args: argparse.Namespace = None, - ): - self.args = args - self.device = torch.device('cuda') - self.split = split - self.env_num = env_num - self.save_video = args.save_video - self.output_path = output_path - self.epoch = epoch - self.config_path = config_path - self.config = get_habitat_config(config_path) - self.agent_config = get_agent_config(self.config.habitat.simulator) - self.sim_sensors_config = self.config.habitat.simulator.agents.main_agent.sim_sensors - - with habitat.config.read_write(self.config): - # self.config.habitat.task.measurements.success.success_distance=3.0 - self.config.habitat.dataset.split = self.split - self.config.habitat.task.measurements.update( - { - "top_down_map": TopDownMapMeasurementConfig( - map_padding=3, - map_resolution=1024, - draw_source=True, - draw_border=True, - draw_shortest_path=True, - draw_view_points=True, - draw_goal_positions=True, - draw_goal_aabbs=True, - fog_of_war=FogOfWarConfig( - draw=True, - visibility_dist=5.0, - fov=90, - ), - ), - "collisions": CollisionsMeasurementConfig(), - } - ) - - print(f"config = {type(self.config)}") - print(OmegaConf.to_yaml(self.config)) - - self._camera_height = self.sim_sensors_config.rgb_sensor.position[1] - self._min_depth = self.sim_sensors_config.depth_sensor.min_depth - self._max_depth = self.sim_sensors_config.depth_sensor.max_depth - - camera_fov_rad = np.deg2rad(self.sim_sensors_config.depth_sensor.hfov) - self._camera_fov = camera_fov_rad - self._fx = self._fy = self.sim_sensors_config.depth_sensor.width / (2 * np.tan(camera_fov_rad / 2)) - - self.model = model - self.processor = processor - - prompt = "You are an autonomous navigation assistant. Your task is to . Where should you go next to stay on track? Please output the next waypoint\'s coordinates in the image. Please output STOP when you have successfully completed the task." 
- answer = "" - self.conversation = [{"from": "human", "value": prompt}, {"from": "gpt", "value": answer}] - - self.conjunctions = [ - 'you can see ', - 'in front of you is ', - 'there is ', - 'you can spot ', - 'you are toward the ', - 'ahead of you is ', - 'in your sight is ', - ] - - self.actions2idx = OrderedDict( - { - 'STOP': [0], - "↑": [1], - "←": [2], - "→": [3], - "↓": [5], - } - ) - - self.objectnav_instructions = ["Search for the {target_object}."] - - self.num_frames = args.num_frames - self.num_future_steps = args.num_future_steps - self.num_history = args.num_history - - def preprocess_depth_image_v2( - self, depth_image, do_depth_scale=True, depth_scale=1000, target_height=None, target_width=None - ): - if target_height is None: - target_height = self.image_processor.crop_size['height'] # 384 - target_width = self.image_processor.crop_size['width'] # 384 - - resized_depth_image = depth_image.resize((target_width, target_height), Image.NEAREST) - - img = to_numpy_array(resized_depth_image) - if do_depth_scale: - img = img / depth_scale - - return img, (target_width, target_height) - - def get_intrinsic_matrix(self, sensor_cfg) -> np.ndarray: - width = sensor_cfg.width - height = sensor_cfg.height - fov = sensor_cfg.hfov - fx = (width / 2.0) / np.tan(np.deg2rad(fov / 2.0)) - fy = fx # Assuming square pixels (fx = fy) - cx = (width - 1.0) / 2.0 - cy = (height - 1.0) / 2.0 - - intrinsic_matrix = np.array( - [[fx, 0.0, cx, 0.0], [0.0, fy, cy, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]] - ) - return intrinsic_matrix - - def preprocess_instrinsic(self, intrinsic, ori_size, target_size): # (V, 4, 4) (resize_shape) (h, w) - intrinsic = copy.deepcopy(intrinsic) - if len(intrinsic.shape) == 2: - intrinsic = intrinsic[None, :, :] # (1, 4, 4) or (B, 4, 4) - - intrinsic[:, 0] /= ori_size[0] / target_size[0] # width - intrinsic[:, 1] /= ori_size[1] / target_size[1] # height - - # for crop transform - intrinsic[:, 0, 2] -= (target_size[0] - target_size[1]) / 2 - - if intrinsic.shape[0] == 1: - intrinsic = intrinsic.squeeze(0) - - return intrinsic - - def get_axis_align_matrix(self): - ma = np.array([[0, 0, 1, 0], [-1, 0, 0, 0], [0, -1, 0, 0], [0, 0, 0, 1]]) - return ma - - def xyz_yaw_to_tf_matrix(self, xyz: np.ndarray, yaw: float) -> np.ndarray: - x, y, z = xyz - transformation_matrix = np.array( - [ - [np.cos(yaw), -np.sin(yaw), 0, x], - [np.sin(yaw), np.cos(yaw), 0, y], - [0, 0, 1, z], - [0, 0, 0, 1], - ] - ) - return transformation_matrix - - def xyz_pitch_to_tf_matrix(self, xyz: np.ndarray, pitch: float) -> np.ndarray: - """Converts a given position and pitch angle to a 4x4 transformation matrix. - - Args: - xyz (np.ndarray): A 3D vector representing the position. - pitch (float): The pitch angle in radians for y axis. - Returns: - np.ndarray: A 4x4 transformation matrix. - """ - - x, y, z = xyz - transformation_matrix = np.array( - [ - [np.cos(pitch), 0, np.sin(pitch), x], - [0, 1, 0, y], - [-np.sin(pitch), 0, np.cos(pitch), z], - [0, 0, 0, 1], - ] - ) - return transformation_matrix - - def xyz_yaw_pitch_to_tf_matrix(self, xyz: np.ndarray, yaw: float, pitch: float) -> np.ndarray: - """Converts a given position and yaw, pitch angles to a 4x4 transformation matrix. - - Args: - xyz (np.ndarray): A 3D vector representing the position. - yaw (float): The yaw angle in radians. - pitch (float): The pitch angle in radians for y axis. - Returns: - np.ndarray: A 4x4 transformation matrix. 
- """ - x, y, z = xyz - rot1 = self.xyz_yaw_to_tf_matrix(xyz, yaw)[:3, :3] - rot2 = self.xyz_pitch_to_tf_matrix(xyz, pitch)[:3, :3] - transformation_matrix = np.eye(4) - transformation_matrix[:3, :3] = rot1 @ rot2 - transformation_matrix[:3, 3] = xyz - return transformation_matrix - - def pixel_to_gps(self, pixel, depth, intrinsic, tf_camera_to_episodic): - ''' - Args: - pixel: (2,) - [u, v] pixel coordinates - depth: (H, W) - depth image where depth[v, u] gives depth in meters - intrinsic: (4, 4) - camera intrinsic matrix - tf_camera_to_episodic: (4, 4) - transformation from camera to episodic frame - Returns: - (x, y): (x, y) coordinates in the episodic frame - ''' - v, u = pixel - z = depth[v, u] - print("depthhhhhhhhhhhhhh", z) - - x = (u - intrinsic[0, 2]) * z / intrinsic[0, 0] - y = (v - intrinsic[1, 2]) * z / intrinsic[1, 1] - point_camera = np.array([x, y, z, 1.0]) - - # Transform to episodic frame - point_episodic = tf_camera_to_episodic @ point_camera - point_episodic = point_episodic[:3] / point_episodic[3] - - x = point_episodic[0] - y = point_episodic[1] - - return (x, y) # same as habitat gps - - def config_env(self) -> Env: - env = Env(config=self.config) - # env.episodes = env.episodes[0:1] - return env - - def dot_matrix_two_dimensional( - self, - image_or_image_path, - save_path=None, - dots_size_w=8, - dots_size_h=8, - save_img=False, - font_path='fonts/arial.ttf', - pixel_goal=None, - ): - """ - takes an original image as input, save the processed image to save_path. Each dot is labeled with two-dimensional Cartesian coordinates (x,y). Suitable for single-image tasks. - control args: - 1. dots_size_w: the number of columns of the dots matrix - 2. dots_size_h: the number of rows of the dots matrix - """ - with open_image(image_or_image_path) as img: - if img.mode != 'RGB': - img = img.convert('RGB') - draw = ImageDraw.Draw(img, 'RGB') - - width, height = img.size - grid_size_w = dots_size_w + 1 - grid_size_h = dots_size_h + 1 - cell_width = width / grid_size_w - cell_height = height / grid_size_h - - font = ImageFont.truetype(font_path, width // 40) # Adjust font size if needed; default == width // 40 - - target_i = target_j = None - if pixel_goal is not None: - y_pixel, x_pixel = pixel_goal[0], pixel_goal[1] - # Validate pixel coordinates - if not (0 <= x_pixel < width and 0 <= y_pixel < height): - raise ValueError(f"pixel_goal {pixel_goal} exceeds image dimensions ({width}x{height})") - - # Convert to grid coordinates - target_i = round(x_pixel / cell_width) - target_j = round(y_pixel / cell_height) - - # Validate grid bounds - if not (1 <= target_i <= dots_size_w and 1 <= target_j <= dots_size_h): - raise ValueError( - f"pixel_goal {pixel_goal} maps to grid ({target_j},{target_i}), " - f"valid range is (1,1)-({dots_size_h},{dots_size_w})" - ) - - count = 0 - - for j in range(1, grid_size_h): - for i in range(1, grid_size_w): - x = int(i * cell_width) - y = int(j * cell_height) - - pixel_color = img.getpixel((x, y)) - # choose a more contrasting color from black and white - if pixel_color[0] + pixel_color[1] + pixel_color[2] >= 255 * 3 / 2: - opposite_color = (0, 0, 0) - else: - opposite_color = (255, 255, 255) - - if pixel_goal is not None and i == target_i and j == target_j: - opposite_color = (255, 0, 0) # Red for target - - circle_radius = width // 240 # Adjust dot size if needed; default == width // 240 - draw.ellipse( - [(x - circle_radius, y - circle_radius), (x + circle_radius, y + circle_radius)], - fill=opposite_color, - ) - - text_x, text_y = x + 3, y - 
count_w = count // dots_size_w - count_h = count % dots_size_w - label_str = f"({count_w+1},{count_h+1})" - draw.text((text_x, text_y), label_str, fill=opposite_color, font=font) - count += 1 - if save_img: - print(">>> dots overlaid image processed, stored in", save_path) - img.save(save_path) - return img - - def _pointnav( - self, - goal: np.ndarray, - depth: np.ndarray, - step_id: int, - robot_xy: np.ndarray, - robot_heading: float, - stop: bool = False, - ) -> Tensor: - ''' - Args: - goal (np.ndarray): goal position - stop (bool): whether to stop - Returns: - action: action tensor - ''' - - masks = torch.tensor([step_id != 0], dtype=torch.bool, device="cuda") - if not np.array_equal(goal, self._last_goal): - if np.linalg.norm(goal - self._last_goal) > 0.1: - self._pointnav_policy.reset() - print('Pointnav policy reset!') - masks = torch.zeros_like(masks) - self._last_goal = goal - rho, theta = rho_theta(robot_xy, robot_heading, goal) - rho_theta_tensor = torch.tensor([[rho, theta]], device="cuda", dtype=torch.float32) - obs_pointnav = { - "depth": image_resize( - depth, - (self._pointnav_depth_image_shape[0], self._pointnav_depth_image_shape[1]), - channels_last=True, - interpolation_mode="area", - ), - "pointgoal_with_gps_compass": rho_theta_tensor, - } - - if rho < self._pointnav_stop_radius and stop: - return 0 - action = self._pointnav_policy.act(obs_pointnav, masks, deterministic=True) - return action - - def eval_action(self, idx) -> None: # noqa: C901 - self.model.eval() - env = self.config_env() - scene_episode_dict = {} - for episode in env.episodes: - if episode.scene_id not in scene_episode_dict: - scene_episode_dict[episode.scene_id] = [] - scene_episode_dict[episode.scene_id].append(episode) - - intrinsic_matrix = self.get_intrinsic_matrix( - self.config.habitat.simulator.agents.main_agent.sim_sensors.rgb_sensor - ) - sucs, spls, oss, nes = [], [], [], [] - done_res = [] - - if os.path.exists(os.path.join(self.output_path, 'result.json')): - with open(os.path.join(self.output_path, 'result.json'), 'r') as f: - for line in f.readlines(): - res = json.loads(line) - done_res.append([res["scene_id"], res["episode_id"], res["episode_instruction"]]) - if get_rank() == 0: # noqa: F405 - sucs.append(res['success']) - spls.append(res['spl']) - oss.append(res['os']) - nes.append(res['ne']) - - for scene in sorted(scene_episode_dict.keys()): - episodes = scene_episode_dict[scene] - scene_id = scene.split('/')[-2] - print(f"scene_id = {scene_id}") - process_bar = tqdm.tqdm(range(len(episodes[idx :: self.env_num])), desc=f"scene {scene_id}") - for episode in episodes[idx :: self.env_num]: - episode_instruction = ( - episode.instruction.instruction_text - if 'objectnav' not in self.config_path - else episode.object_category - ) - print("episode start", episode_instruction) - episode_id = int(episode.episode_id) - if [scene_id, episode_id, episode_instruction] in done_res: - continue - - env.current_episode = episode - observations = env.reset() - - agent_state = env.sim.get_agent_state() - rotation = agent_state.rotation - translation = agent_state.position - rotation_matrix = quaternion.as_rotation_matrix(rotation) - transformation_matrix = np.eye(4) - transformation_matrix[:3, :3] = rotation_matrix - transformation_matrix[:3, 3] = translation - - agent = ShortestPathFollower(env.sim, 0.25, False) - - os.makedirs(os.path.join(self.output_path, f'check_sim_{self.epoch}'), exist_ok=True) - Image.fromarray(observations['rgb']).save( - os.path.join(self.output_path, 
f'check_sim_{self.epoch}', f'rgb_{idx}.jpg') - ) - - vis_frames = [] - step_id = 0 - - if self.save_video: - os.makedirs(os.path.join(self.output_path, f'vis_{self.epoch}', f'{scene_id}'), exist_ok=True) - initial_height = env.sim.get_agent_state().position[1] - - rgb_list = [] - action_seq = [] - output_ids = None - - goal = None - action = None - messages = [] - local_actions = [] - - while not env.episode_over and step_id <= 500: - rgb = observations["rgb"] - depth = observations["depth"] - x, y = observations["gps"] - camera_yaw = observations["compass"][0] - depth = filter_depth(depth.reshape(depth.shape[:2]), blur_type=None) - depth = depth * (self._max_depth - self._min_depth) + self._min_depth - depth = depth * 1000 - - agent_state = env.sim.get_agent_state() - height = agent_state.position[1] - initial_height - camera_position = np.array([x, -y, self._camera_height + height]) - tf_camera_to_episodic = ( - self.xyz_yaw_pitch_to_tf_matrix(camera_position, camera_yaw, np.deg2rad(30)) - @ self.get_axis_align_matrix() - ) - - image = Image.fromarray(rgb).convert('RGB') - save_raw_image = image.copy() - - save_dot = False - if action == 5: - look_down_image = image - save_raw_image = look_down_image.copy() - look_down_depth, resize_shape = self.preprocess_depth_image_v2( - Image.fromarray(depth.astype(np.uint16), mode='I;16'), - do_depth_scale=True, - depth_scale=1000, - target_height=224, - target_width=224, - ) - look_down_depth = torch.as_tensor(np.ascontiguousarray(look_down_depth)).float() - look_down_depth[look_down_depth > 5.0] = 5.0 - else: - image = image.resize((self.args.resize_w, self.args.resize_h)) - rgb_list.append(image) - - if self.args.mode == 'dual_system': - down_observations = env.step(5) - down_observations = env.step(5) - - look_down_image = Image.fromarray(down_observations["rgb"]).convert('RGB') - depth = down_observations["depth"] - depth = filter_depth(depth.reshape(depth.shape[:2]), blur_type=None) - depth = depth * (self._max_depth - self._min_depth) + self._min_depth - depth = depth * 1000 - look_down_depth, resize_shape = self.preprocess_depth_image_v2( - Image.fromarray(depth.astype(np.uint16), mode='I;16'), - do_depth_scale=True, - depth_scale=1000, - target_height=224, - target_width=224, - ) - look_down_depth = torch.as_tensor(np.ascontiguousarray(look_down_depth)).float() - look_down_depth[look_down_depth > 5.0] = 5.0 - - env.step(4) - env.step(4) - - info = env.get_metrics() - - if len(action_seq) == 0 and goal is None: - if action != 5: - sources = copy.deepcopy(self.conversation) - sources[0]["value"] = sources[0]["value"].replace( - '.', episode.instruction.instruction_text[:-1] - ) - cur_images = rgb_list[-1:] - if step_id == 0: - history_id = [] - else: - history_id = np.unique( - np.linspace(0, step_id - 1, self.num_history, dtype=np.int32) - ).tolist() - placeholder = (DEFAULT_IMAGE_TOKEN + '\n') * len(history_id) - sources[0]["value"] += f' These are your historical observations: {placeholder}.' - - history_id = sorted(history_id) - print('history_idddddddd', step_id, history_id) - input_images = [rgb_list[i] for i in history_id] + cur_images - input_img_id = 0 - else: - assert action == 5 - sources = [{"from": "human", "value": ""}, {"from": "gpt", "value": ""}] - input_images += [look_down_image] - messages.append( - {'role': 'assistant', 'content': [{'type': 'text', 'text': llm_outputs}]} # noqa: F405 - ) - input_img_id = -1 - - prompt = random.choice(self.conjunctions) + DEFAULT_IMAGE_TOKEN - sources[0]["value"] += f" {prompt}." 
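# Sketch of the history-frame selection above: past steps are subsampled with a
# uniform linspace and deduplicated, so short episodes simply keep every frame.
import numpy as np

def select_history_ids(step_id: int, num_history: int):
    if step_id == 0:
        return []
    return np.unique(np.linspace(0, step_id - 1, num_history, dtype=np.int32)).tolist()

print(select_history_ids(3, 8))   # [0, 1, 2] -- fewer frames than num_history
print(select_history_ids(40, 8))  # 8 evenly spaced indices in [0, 39]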
- print('sources', step_id, sources) - prompt_instruction = copy.deepcopy(sources[0]["value"]) - parts = split_and_clean(prompt_instruction) - - content = [] - for i in range(len(parts)): - if parts[i] == "": - content.append({"type": "image", "image": input_images[input_img_id]}) - input_img_id += 1 - else: - content.append({"type": "text", "text": parts[i]}) - - messages.append({'role': 'user', 'content': content}) - - print('step_id', step_id, 'messages:', messages) - - text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - - inputs = self.processor(text=[text], images=input_images, return_tensors="pt").to( - self.model.device - ) - - with torch.no_grad(): - output_ids = self.model.generate(**inputs, max_new_tokens=128, do_sample=False) - - llm_outputs = self.processor.tokenizer.decode( - output_ids[0][inputs.input_ids.shape[1] :], skip_special_tokens=True - ) - print('step_id:', step_id, 'output text:', llm_outputs) - - if bool(re.search(r'\d', llm_outputs)): - forward_action = 0 - coord = [int(c) for c in re.findall(r'\d+', llm_outputs)] - pixel_goal = [int(coord[1]), int(coord[0])] - - goal = self.pixel_to_gps(pixel_goal, depth / 1000, intrinsic_matrix, tf_camera_to_episodic) - print('before', goal, depth.shape) - goal = (transformation_matrix @ np.array([-goal[1], 0, -goal[0], 1]))[:3] - - if not env.sim.pathfinder.is_navigable(np.array(goal)): - goal = np.array(env.sim.pathfinder.snap_point(np.array(goal))) - - # look down --> horizontal - env.step(4) - env.step(4) - - # Forking logic based on mode - if self.args.mode == 'system2': - action = agent.get_next_action(goal) - if action == 0: - goal = None - output_ids = None - action = 2 # random action - print('conduct a random action 2') - observations = env.step(action) - step_id += 1 - messages = [] - continue - else: # dual-system logic - local_actions = [] - pixel_values = inputs.pixel_values - image_grid_thw = torch.cat([thw.unsqueeze(0) for thw in inputs.image_grid_thw], dim=0) - - with torch.no_grad(): - traj_latents = self.model.generate_latents(output_ids, pixel_values, image_grid_thw) - - # prepocess align with navdp - image_dp = ( - torch.tensor(np.array(look_down_image.resize((224, 224)))).to(torch.bfloat16) / 255 - ) - pix_goal_image = copy.copy(image_dp) - images_dp = torch.stack([pix_goal_image, image_dp]).unsqueeze(0).to(self.device) - depth_dp = look_down_depth.unsqueeze(-1).to(torch.bfloat16) - pix_goal_depth = copy.copy(depth_dp) - depths_dp = torch.stack([pix_goal_depth, depth_dp]).unsqueeze(0).to(self.device) - - with torch.no_grad(): - dp_actions = self.model.generate_traj( - traj_latents, images_dp, depths_dp, use_async=True - ) - - random_choice = np.random.choice(dp_actions.shape[0]) - if self.args.continuous_traj: - action_list = traj_to_actions(dp_actions) - if len(action_list) < 8: - action_list += [0] * (8 - len(action_list)) - else: - action_list = chunk_token(dp_actions[random_choice]) - - local_actions = action_list - if len(local_actions) >= 4: - local_actions = local_actions[:4] - action = local_actions[0] - if action == 0: - goal = None - output_ids = None - action = 2 # random action - print('conduct a random action 2') - observations = env.step(action) - step_id += 1 - messages = [] - continue - - print('predicted goal', pixel_goal, goal, flush=True) - else: - action_seq = self.parse_actions(llm_outputs) - print('actions', action_seq, flush=True) - - if len(action_seq) != 0: - action = action_seq[0] - action_seq.pop(0) - elif goal is not None: - # Forking 
logic based on mode - if self.args.mode == 'system2': - action = agent.get_next_action(goal) - action = action.detach().cpu().numpy()[0] if isinstance(action, torch.Tensor) else action - action = action[0] if hasattr(action, "__len__") else action - else: # dual-system logic - if len(local_actions) == 0: - # navdp - local_actions = [] - image_dp = ( - torch.tensor(np.array(look_down_image.resize((224, 224)))).to(torch.bfloat16) / 255 - ) - - images_dp = torch.stack([pix_goal_image, image_dp]).unsqueeze(0).to(self.device) - depth_dp = look_down_depth.unsqueeze(-1).to(torch.bfloat16) - - depths_dp = torch.stack([pix_goal_depth, depth_dp]).unsqueeze(0).to(self.device) - with torch.no_grad(): - dp_actions = self.model.generate_traj( - traj_latents, images_dp, depths_dp, use_async=True - ) - - random_choice = np.random.choice(dp_actions.shape[0]) - if self.args.continuous_traj: - action_list = traj_to_actions(dp_actions) - if len(action_list) < 8: - action_list += [0] * (8 - len(action_list)) - else: - action_list = chunk_token(dp_actions[random_choice]) - print("first action_list", action_list) - - local_actions = action_list - if len(local_actions) >= 4: - local_actions = local_actions[:4] - # if len(local_actions) >= 2: - # local_actions = local_actions[:2] - - print("local_actions", local_actions) - - action = local_actions.pop(0) - # navdp - else: - action = local_actions.pop(0) - - forward_action += 1 - print('forward_action', forward_action, flush=True) - if forward_action > 8: - goal = None - output_ids = None - messages = [] - step_id += 1 - forward_action = 0 - local_actions = [] - continue - if action == 0: - goal = None - output_ids = None - messages = [] - step_id += 1 - forward_action = 0 - local_actions = [] - continue - else: - action = 0 - - if info['top_down_map'] is not None: - if save_dot: - save_raw_image = self.dot_matrix_two_dimensional( - save_raw_image, save_img=False, save_path=f'test_{step_id}.jpg', pixel_goal=pixel_goal - ) - frame = observations_to_image({'rgb': np.asarray(save_raw_image)}, info) - vis_frames.append(frame) - - print("step_id", step_id, "action", action) - - if action == 5: - env.step(action) - observations = env.step(action) - else: - observations = env.step(action) - step_id += 1 - messages = [] - - process_bar.update(1) - - metrics = env.get_metrics() - if self.save_video: - images_to_video( - vis_frames, - os.path.join(self.output_path, f'vis_{self.epoch}', f'{scene_id}'), - f'{episode_id:04d}', - fps=6, - quality=9, - ) - vis_frames.clear() - sucs.append(metrics['success']) - spls.append(metrics['spl']) - oss.append(metrics['oracle_success']) - nes.append(metrics["distance_to_goal"]) - print( - f"scene_episode {scene_id}_{episode_id:04d} success: {metrics['success']}, spl: {metrics['spl']}, os: {metrics['oracle_success']}, ne: {metrics['distance_to_goal']}" - ) - - result = { - "scene_id": scene_id, - "episode_id": episode_id, - "success": metrics["success"], - "spl": metrics["spl"], - "os": metrics['oracle_success'], - "ne": metrics["distance_to_goal"], - "steps": step_id, - "episode_instruction": episode_instruction, - } - - with open(os.path.join(self.output_path, 'result.json'), 'a') as f: - f.write(json.dumps(result) + "\n") - env.close() - return ( - torch.tensor(sucs).to(self.device), - torch.tensor(spls).to(self.device), - torch.tensor(oss).to(self.device), - torch.tensor(nes).to(self.device), - torch.tensor(len(sucs)).to(self.device), - ) - - def parse_actions(self, output): - action_patterns = '|'.join(re.escape(action) for action in 
self.actions2idx) - # import ipdb; ipdb.set_trace() - regex = re.compile(action_patterns) - matches = regex.findall(output) - actions = [self.actions2idx[match] for match in matches] - actions = itertools.chain.from_iterable(actions) - return list(actions) - - def preprocess_qwenvl(self, source): - prompt = random.choice(self.conjunctions) + DEFAULT_IMAGE_TOKEN - if len(source[0]["value"]) != 0: - source[0]["value"] += f" {prompt}." - else: - source[0]["value"] = f"{prompt}." # Please output the next waypoint\'s coordinates in the image." - return source diff --git a/internnav/evaluator/utils/common.py b/internnav/evaluator/utils/common.py index f855014..08dde7e 100644 --- a/internnav/evaluator/utils/common.py +++ b/internnav/evaluator/utils/common.py @@ -1,9 +1,5 @@ -import copy -import gzip -import json import math import os -from collections import defaultdict import numpy as np from PIL import Image, ImageDraw @@ -133,143 +129,6 @@ def check_is_on_track( return True -def has_stairs(item, height_threshold=0.3): - has_stairs = False - if 'stair' in item['instruction']['instruction_text']: - latest_height = item['reference_path'][0][-1] - for index in range(1, len(item['reference_path'])): - position = item['reference_path'][index] - if abs(position[-1] - latest_height) >= height_threshold: - has_stairs = True - break - else: - latest_height = position[-1] - return has_stairs - - -def different_height(item): - different_height = False - paths = item['reference_path'] - for path_idx in range(len(paths) - 1): - if abs(paths[path_idx + 1][2] - paths[path_idx][2]) > 0.3: - different_height = True - break - return different_height - - -def transform_rotation_z_90degrees(rotation): - z_rot_90 = [np.cos(np.pi / 4), 0, 0, np.sin(np.pi / 4)] # 90 degrees = pi/2 radians - w1, x1, y1, z1 = rotation - w2, x2, y2, z2 = z_rot_90 - revised_rotation = [ - w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2, # w - w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2, # x - w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2, # y - w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2, # z - ] - return revised_rotation - - -def load_data(dataset_root_dir, split, filter_same_trajectory=True, filter_stairs=True, dataset_type='mp3d'): - with gzip.open(os.path.join(dataset_root_dir, split, f"{split}.json.gz"), 'rt', encoding='utf-8') as f: - data = json.load(f)['episodes'] - - if dataset_type == 'mp3d': - scenes = list(set([x['scene_id'] for x in data])) # e.g. 'mp3d/zsNo4HB9uLZ/zsNo4HB9uLZ.glb' - elif dataset_type == 'kujiale': - scenes = list(set([x['scan'] for x in data])) - else: - raise Exception(f"Unsupported dataset type {dataset_type}, please update cfg to contain valid dataset_type") - scenes.sort() - new_data = {} - for scene in scenes: - if dataset_type == 'mp3d': - scene_data = [x for x in data if x['scene_id'] == scene] - scan = scene.split('/')[1] # e.g. 
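# Sketch of the quaternion composition in transform_rotation_z_90degrees above:
# a Hamilton product with [w, x, y, z] ordering, composing a start rotation with
# a fixed 90-degree rotation about z.
import numpy as np

def quat_mul(q1, q2):
    w1, x1, y1, z1 = q1
    w2, x2, y2, z2 = q2
    return [
        w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
        w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
        w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
        w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
    ]

z_rot_90 = [np.cos(np.pi / 4), 0.0, 0.0, np.sin(np.pi / 4)]
print(quat_mul([1.0, 0.0, 0.0, 0.0], z_rot_90))  # identity * q == q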
'zsNo4HB9uLZ' - else: - scene_data = [x for x in data if x['scan'] == scene] - scan = scene - new_scene_data = [] - for item in scene_data: - new_item = copy.deepcopy(item) - new_item['scan'] = scan - new_item['original_start_position'] = item['start_position'] - new_item['original_start_rotation'] = item['start_rotation'] - if dataset_type == 'mp3d': - x, z, y = item['start_position'] - new_item['start_position'] = [x, -y, z] - r1, r2, r3, r4 = item['start_rotation'] - new_item['start_rotation'] = transform_rotation_z_90degrees([-r4, r1, r3, -r2]) - new_item['reference_path'] = [[x, -y, z] for x, z, y in item['reference_path']] - new_scene_data.append(new_item) - - new_data[scan] = new_scene_data - - data = copy.deepcopy(new_data) - new_data = defaultdict(list) - - # filter_same_trajectory - if filter_same_trajectory: - total_count = 0 - remaining_count = 0 - trajectory_list = [] - for scan, data_item in data.items(): - for item in data_item: - total_count += 1 - if item['trajectory_id'] in trajectory_list: - continue - remaining_count += 1 - trajectory_list.append(item['trajectory_id']) - new_data[scan].append(item) - log.info(f'[split:{split}]filter_same_trajectory remain: [ {remaining_count} / {total_count} ]') - data = new_data - new_data = defaultdict(list) - - if filter_stairs: - total_count = 0 - remaining_count = 0 - for scan, data_item in data.items(): - for item in data_item: - total_count += 1 - if has_stairs(item) or different_height(item): - continue - remaining_count += 1 - new_data[scan].append(item) - log.info(f'[split:{split}]filter_stairs remain: [ {remaining_count} / {total_count} ]') - data = new_data - - return data - - -def load_scene_usd(mp3d_data_dir, scan): - """Load scene USD based on the scan""" - from internutopia.core.util import is_in_container - - find_flag = False - for root, dirs, files in os.walk(os.path.join(mp3d_data_dir, scan)): - target_file_name = 'fixed_docker.usd' if is_in_container() else 'fixed.usd' - for file in files: - if file == target_file_name: - scene_usd_path = os.path.join(root, file) - find_flag = True - break - if find_flag: - break - if not find_flag: - log.error('Scene USD not found for scan %s', scan) - return None - return scene_usd_path - - -def load_kujiale_scene_usd(kujiale_iros_data_dir, scan): - """Load scene USD based on the scan""" - scene_usd_path = os.path.join(kujiale_iros_data_dir, scan, f'{scan}.usda') - if not os.path.exists(scene_usd_path): - log.error('Scene USD not found for scan %s', scan) - return None - return scene_usd_path - - def get_new_position_and_rotation(robot_position, robot_rotation, action): from omni.isaac.core.utils.rotations import ( euler_angles_to_quat, @@ -605,7 +464,7 @@ def obs_to_image(obs_lst, action, output_path: str, reference_path, normalize: b topdown_array = first_obs['topdown_rgb'] # draw array on rgb array - rgb_array = draw_action_pil(rgb_array, action) + rgb_array = cv2.resize(draw_action_pil(rgb_array, action), (256, 256)) # draw trajectory on depth topdown_array = crop(draw_trajectory(topdown_array, obs_lst, reference_path)) diff --git a/internnav/evaluator/utils/data_collector.py b/internnav/evaluator/utils/data_collector.py index f721d82..0cc9117 100644 --- a/internnav/evaluator/utils/data_collector.py +++ b/internnav/evaluator/utils/data_collector.py @@ -7,12 +7,14 @@ class DataCollector: - def __init__(self, lmdb_path): + def __init__(self, lmdb_path, rank=0, world_size=1): if not os.path.exists(lmdb_path): os.makedirs(lmdb_path) self.lmdb_path = lmdb_path 
self.episode_total_data = [] self.actions = [] + self.rank = rank + self.world_size = world_size def collect_observation(self, rgb, depth, step, process, camera_pose, robot_pose): from omni.isaac.core.utils.rotations import quat_to_euler_angles @@ -104,7 +106,12 @@ def save_sample_data(self, key, result, instruction): if result != 'success': finish_flag = 'fail' lmdb_file = os.path.join(self.lmdb_path, 'sample_data.lmdb') - database = lmdb.open(lmdb_file, map_size=1 * 1024 * 1024 * 1024 * 1024, max_dbs=0) + database = lmdb.open( + lmdb_file, + map_size=1 * 1024 * 1024 * 1024 * 1024, + max_dbs=0, + lock=True, + ) with database.begin(write=True) as txn: encode_key = key.encode() episode_datas = self.merge_data(self.episode_total_data, self.actions) @@ -126,9 +133,10 @@ def save_eval_result(self, key, result, info): if result != 'success': finish_flag = 'fail' database_write = lmdb.open( - f'{self.lmdb_path}/sample_data.lmdb', + f'{self.lmdb_path}/sample_data{self.rank}.lmdb', map_size=1 * 1024 * 1024 * 1024 * 1024, max_dbs=0, + lock=True, ) with database_write.begin(write=True) as txn: key_write = key.encode() diff --git a/internnav/evaluator/utils/dataset.py b/internnav/evaluator/utils/result_logger.py similarity index 67% rename from internnav/evaluator/utils/dataset.py rename to internnav/evaluator/utils/result_logger.py index c4f224c..e3826a6 100644 --- a/internnav/evaluator/utils/dataset.py +++ b/internnav/evaluator/utils/result_logger.py @@ -1,112 +1,16 @@ -import os -import sys +import collections import json +import os import lmdb import msgpack_numpy from internnav import PROJECT_ROOT_PATH -from internnav.evaluator.utils.common import load_data -from internnav.evaluator.utils.config import get_lmdb_path, get_lmdb_prefix - from internnav.configs.evaluator import EvalDatasetCfg -from .config import Config +from internnav.env.utils.episode_loader.dataset_utils import load_data +from internnav.evaluator.utils.config import get_lmdb_path - -def split_data(dataset_cfg: EvalDatasetCfg): - if isinstance(dataset_cfg.dataset_settings, dict): - config = Config(**dataset_cfg.dataset_settings) - run_type = config.run_type - split_number = 1 # config.total_rank - run_type = run_type - name = config.task_name - split_data_types = config.split_data_types - base_data_dir = config.base_data_dir - filter_stairs = config.filter_stairs - - print(f'run_type:{run_type}') - print(f'name:{name}') - print(f'split_data_types:{split_data_types}') - prefix = get_lmdb_prefix(run_type) - if run_type == 'eval': - filter_same_trajectory = False - elif run_type == 'sample': - filter_same_trajectory = True - else: - print(f'unknown run_type:{run_type}') - sys.exit() - - lmdb_path = get_lmdb_path(name) - # get all data - path_key_map = {} - count = 0 - - dataset_type = dataset_cfg.dataset_type - for split_data_type in split_data_types: - data_map = load_data( - base_data_dir, - split_data_type, - filter_same_trajectory=filter_same_trajectory, - filter_stairs=filter_stairs, - dataset_type=dataset_type, - ) - for scan, path_list in data_map.items(): - path_key_list = [] - for path in path_list: - trajectory_id = path['trajectory_id'] - episode_id = path['episode_id'] - path_key = f'{trajectory_id}_{episode_id}' - path_key_list.append(path_key) - path_key_map[scan] = path_key_list - count += len(path_key_list) - - print(f'TOTAL:{count}') - - # split rank - rank_map = {} - split_length = count // split_number - index = -1 - for scan, path_key_list in path_key_map.items(): - for path_key in path_key_list: - index += 1 - 
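# Minimal sketch of the per-rank LMDB layout introduced above: each evaluator rank
# writes its own sample_data{rank}.lmdb with locking enabled, so concurrent ranks
# never share a single writer (path and record below are hypothetical).
import os
import lmdb
import msgpack_numpy

def save_eval_result(lmdb_path: str, rank: int, key: str, record: dict):
    db = lmdb.open(
        os.path.join(lmdb_path, f"sample_data{rank}.lmdb"),
        map_size=1 * 1024 * 1024 * 1024 * 1024,  # 1 TiB virtual map size
        max_dbs=0,
        lock=True,
    )
    with db.begin(write=True) as txn:
        txn.put(key.encode(), msgpack_numpy.packb(record, use_bin_type=True))
    db.close()

# save_eval_result("data/eval_lmdb", rank=0, key="traj_12_3", record={"info": {"success": 1.0}})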
rank = index // split_length - if rank >= split_number: - rank = split_number - 1 - rank_map[path_key] = rank - - ranked_data = {} - for i in range(split_number): - filtered_path_key_map = {} - for scan, path_key_list in path_key_map.items(): - filtered_list = [] - for path_key in path_key_list: - if rank_map[path_key] == i: - filtered_list.append(path_key) - if len(filtered_list) > 0: - filtered_path_key_map[scan] = filtered_list - ranked_data[i] = filtered_path_key_map - - for rank, path_key_map in ranked_data.items(): - count = 0 - for scan, path_key_list in path_key_map.items(): - count += len(path_key_list) - print(f'[rank:{rank}][scan:{scan}][count:{len(path_key_list)}]') - print(f'[rank:{rank}][count:{count}]') - - if not os.path.exists(lmdb_path): - os.makedirs(lmdb_path) - database = lmdb.open( - f'{lmdb_path}/sample_data.lmdb', - map_size=1 * 1024 * 1024 * 1024 * 1024, - max_dbs=0, - ) - with database.begin(write=True) as txn: - for rank, path_key_map in ranked_data.items(): - key = f'{prefix}_{rank}'.encode() - value = msgpack_numpy.packb(path_key_map, use_bin_type=True) - txn.put(key, value) - print(f'finish [key:{key}]') - database.close() +from .config import Config class ResultLogger: @@ -150,8 +54,6 @@ def get_split_map( return split_map def write_now_result_json(self): - # create log file - log_content = [] self.database_read = lmdb.open( f'{self.lmdb_path}/sample_data.lmdb', map_size=1 * 1024 * 1024 * 1024 * 1024, @@ -209,23 +111,23 @@ def write_now_result_json(self): reason_map[ret_type] = reason_map[ret_type] + 1 if success > 0: reason_map['reach_goal'] = reason_map['reach_goal'] + 1 - + if count == 0: continue - json_data[split]={} - json_data[split]['TL']=round((total_TL / count),4) - json_data[split]['NE']=round((total_NE / count),4) + json_data[split] = {} + json_data[split]['TL'] = round((total_TL / count), 4) + json_data[split]['NE'] = round((total_NE / count), 4) if 'fall' not in reason_map: reason_map['fall'] = 0 - json_data[split]['FR']=round((reason_map['fall'] / count),4) + json_data[split]['FR'] = round((reason_map['fall'] / count), 4) if 'stuck' in reason_map: - json_data[split]['StR']=round((reason_map['stuck'] / count),4) + json_data[split]['StR'] = round((reason_map['stuck'] / count), 4) else: - json_data[split]['StR']=0 - json_data[split]['OS']=round((total_osr / count),4) - json_data[split]['SR']=round((total_success / count),4) - json_data[split]['SPL']=round((total_spl / count),4) - json_data[split]['Count']=count + json_data[split]['StR'] = 0 + json_data[split]['OS'] = round((total_osr / count), 4) + json_data[split]['SR'] = round((total_success / count), 4) + json_data[split]['SPL'] = round((total_spl / count), 4) + json_data[split]['Count'] = count # write log content to file with open(f'{self.dataset_type}_result.json', 'w') as f: @@ -329,3 +231,93 @@ def log_print(content): f.write('\n'.join(log_content)) self.database_read.close() + + def finalize_all_results(self, rank, world_size): + # accumulator for all splits across all ranks + split_acc = {} + for split in self.split_map.keys(): + split_acc[split] = { + "total_TL": 0.0, + "total_NE": 0.0, + "total_osr": 0.0, + "total_success": 0.0, + "total_spl": 0.0, + "reason_map": collections.Counter({"reach_goal": 0}), + "count": 0, + } + + # loop over all ranks' lmdbs + for i in range(world_size): + lmdb_dir = f"{self.lmdb_path}/sample_data{i}.lmdb" + if not os.path.exists(lmdb_dir): + # this rank might not have produced a db; skip + continue + + env = lmdb.open( + lmdb_dir, + readonly=True, + 
lock=False, + max_readers=256, + ) + + for split, path_key_list in self.split_map.items(): + for path_key in path_key_list: + with env.begin() as txn: + value = txn.get(path_key.encode()) + if value is None: + continue + + data = msgpack_numpy.unpackb(value) + data["path_key"] = path_key + + acc = split_acc[split] + + TL = data["info"]["TL"] + NE = data["info"]["NE"] + if NE < 0: + NE = 0 + osr = data["info"]["osr"] + if osr < 0: + osr = 0 + success = data["info"]["success"] + spl = data["info"]["spl"] + + acc["total_TL"] += TL + acc["total_NE"] += NE + acc["total_osr"] += osr + acc["total_success"] += success + acc["total_spl"] += spl + acc["count"] += 1 + + ret_type = data.get("fail_reason", "") or "success" + acc["reason_map"][ret_type] += 1 + if success > 0: + acc["reason_map"]["reach_goal"] += 1 + + env.close() + + # build final json + json_data = {} + for split, acc in split_acc.items(): + count = acc["count"] + if count == 0: + continue + + reason_map = acc["reason_map"] + fall = reason_map.get("fall", 0) + stuck = reason_map.get("stuck", 0) + + json_data[split] = { + "TL": round(acc["total_TL"] / count, 4), + "NE": round(acc["total_NE"] / count, 4), + "FR": round(fall / count, 4), + "StR": round(stuck / count, 4), + "OS": round(acc["total_osr"] / count, 4), + "SR": round(acc["total_success"] / count, 4), + "SPL": round(acc["total_spl"] / count, 4), + "Count": count, + } + + # write log content to file + with open(f"{self.name}_result.json", "w") as f: + json.dump(json_data, f, indent=2, ensure_ascii=False) diff --git a/internnav/evaluator/vln_multi_evaluator.py b/internnav/evaluator/vln_distributed_evaluator.py similarity index 83% rename from internnav/evaluator/vln_multi_evaluator.py rename to internnav/evaluator/vln_distributed_evaluator.py index e573c7a..0ce616b 100644 --- a/internnav/evaluator/vln_multi_evaluator.py +++ b/internnav/evaluator/vln_distributed_evaluator.py @@ -1,4 +1,3 @@ -import sys from enum import Enum from pathlib import Path from time import time @@ -7,14 +6,12 @@ import numpy as np from internnav.configs.evaluator import EvalCfg -from internnav.evaluator.base import Evaluator +from internnav.evaluator import DistributedEvaluator, Evaluator from internnav.evaluator.utils.common import set_seed_model from internnav.evaluator.utils.config import get_lmdb_path from internnav.evaluator.utils.data_collector import DataCollector -from internnav.evaluator.utils.dataset import ResultLogger, split_data -from internnav.evaluator.utils.eval import generate_episode +from internnav.evaluator.utils.result_logger import ResultLogger from internnav.evaluator.utils.visualize_util import VisualizeUtil -from internnav.projects.dataloader.resumable import ResumablePathKeyDataloader from internnav.utils import common_log_util, progress_log_multi_util from internnav.utils.common_log_util import common_logger as log @@ -27,56 +24,51 @@ class runner_status_code(Enum): STOP = 4 -@Evaluator.register('vln_multi') -class VlnMultiEvaluator(Evaluator): +@Evaluator.register('vln_distributed') +class VLNDistributedEvaluator(DistributedEvaluator): def __init__(self, config: EvalCfg): + start_time = time() + self.task_name = config.task.task_name - if not Path(get_lmdb_path(self.task_name)).exists(): - split_data(config.dataset) self.result_logger = ResultLogger(config.dataset) - common_log_util.init(self.task_name) - self.dataloader = ResumablePathKeyDataloader(config.dataset.dataset_type, **config.dataset.dataset_settings) self.dataset_name = 
Path(config.dataset.dataset_settings['base_data_dir']).name - progress_log_multi_util.init(self.task_name, self.dataloader.size) - self.total_path_num = self.dataloader.size - progress_log_multi_util.progress_logger_multi.info( - f'start eval dataset: {self.task_name}, total_path: {self.dataloader.size}' # noqa: E501 - ) - # generate episode - episodes = generate_episode(self.dataloader, config) - if len(episodes) == 0: - log.info("No more episodes to evaluate. Episodes are saved in data/sample_episodes/") - sys.exit(0) - config.task.task_settings.update({'episodes': episodes}) + config.env.env_settings['dataset'] = config.dataset + + # vec env settings self.env_num = config.task.task_settings['env_num'] self.proc_num = ( config.env.env_settings['distribution_config']['proc_num'] if 'distribution_config' in config.env.env_settings else 1 ) - # check env_num and proc_num - # priority: reduce env_num first then reduce proc_num - while self.env_num > 1 and self.proc_num * self.env_num > self.total_path_num: - self.env_num -= 1 - log.info(f'dataset size is too small! Change env_num to {self.env_num}.') - while self.proc_num > 1 and self.proc_num * self.env_num > self.total_path_num: - self.proc_num -= 1 - log.info(f'dataset size is too small! Change proc_num to {self.proc_num}.') - # update + + # update config config.task.task_settings['env_num'] = self.env_num if 'distribution_config' in config.env.env_settings: config.env.env_settings['distribution_config']['proc_num'] = self.proc_num config.agent.model_settings.update({'env_num': self.env_num, 'proc_num': self.proc_num}) self.robot_name = config.task.robot_name + super().__init__(config) set_seed_model(0) - self.data_collector = DataCollector(self.dataloader.lmdb_path) + + common_log_util.init(self.task_name) + self.total_path_num = len(self.env.episodes) + progress_log_multi_util.init(self.task_name, self.total_path_num) + progress_log_multi_util.progress_logger_multi.info( + f'start eval dataset: {self.task_name}, total_path: {self.total_path_num}' # noqa: E501 + ) + self.data_collector = DataCollector(get_lmdb_path(self.task_name), rank=self.rank, world_size=self.world_size) self.robot_flash = config.task.robot_flash self.save_to_json = config.eval_settings['save_to_json'] self.vis_output = config.eval_settings['vis_output'] self.visualize_util = VisualizeUtil(self.task_name, fps=6) + end_time = time() + duration = round(end_time - start_time, 2) + log.info(f'[TIME] Env Init time: {duration}s') + @property def ignore_obs_attr(self): return [ @@ -96,7 +88,6 @@ def warm_up(self): action=[{self.robot_name: {'stand_still': []}} for _ in range(self.env_num * self.proc_num)] ) if obs[0][self.robot_name]['finish_action']: - print('get_obs') break return obs @@ -135,6 +126,7 @@ def _transform_action_batch(self, actions: List[Dict], flash=False): return transformed_actions def get_action(self, obs, action): + start_time = time() # process obs obs = np.array(obs) fake_obs_index = np.logical_or( @@ -150,6 +142,9 @@ def get_action(self, obs, action): # change warm_up action = np.array(action) action[self.runner_status == runner_status_code.WARM_UP] = {'h1': {'stand_still': []}} + end_time = time() + duration = round(end_time - start_time, 2) + log.info(f'[TIME] agent step time: {duration}s') return obs, action def _need_reset(self, terminated_ls): @@ -162,16 +157,13 @@ def _need_reset(self, terminated_ls): def env_step(self, action): start_time = time() - # stop_count = [0 for _ in range(self.env_num * self.sim_num)] + while True: # stop action 
maybe also need 50 steps self.runner_status[ np.logical_and(self.runner_status == runner_status_code.NORMAL, action == {'h1': {'stop': []}}) ] = runner_status_code.STOP - print(action) - t0 = time() obs, reward, terminated, truncated, info = self.env.step(action=action.tolist()) - print(f"inner one step time {time() - t0}") obs = self._obs_remove_robot_name(obs) finish_status = np.logical_or( np.array([ob['finish_action'] for ob in obs]), @@ -184,14 +176,21 @@ def env_step(self, action): ) or np.logical_and.reduce(np.array(finish_status)): self.runner_status[self.runner_status == runner_status_code.STOP] = runner_status_code.NORMAL break - if __debug__ and np.logical_or.reduce(np.array(finish_status)): - print(f'finish_status: {finish_status}') end_time = time() duration = round(end_time - start_time, 2) - log.info(f'env step time: {duration}s') + log.info(f'[TIME] Env Step time: {duration}s') return obs, terminated def terminate_ops(self, obs_ls, reset_infos, terminated_ls): + """ + 1. reset agent if finished warm up + 2. reset envs that are terminated + 3. start new trace log and visualize log + 4. return whether all envs are terminated + 5. return updated reset_infos + """ + start_time = time() + finish_warmup_ls = (self.runner_status == runner_status_code.WARM_UP) & [ob['finish_action'] for ob in obs_ls] if np.logical_or.reduce(finish_warmup_ls): self.agent.reset(np.where(finish_warmup_ls)[0].tolist()) @@ -225,9 +224,7 @@ def terminate_ops(self, obs_ls, reset_infos, terminated_ls): result=obs['metrics'][list(obs['metrics'].keys())[0]][0]['fail_reason'], ) # json format result - if self.save_to_json: - self.result_logger.write_now_result_json() - self.result_logger.write_now_result() + self.result_logger.finalize_all_results(self.rank, self.world_size) self.runner_status[env_id] = runner_status_code.NOT_RESET log.info(f'env{env_id}: states switch to NOT_RESET.') # need this status to reset @@ -249,7 +246,6 @@ def terminate_ops(self, obs_ls, reset_infos, terminated_ls): reset_infos = reset_infos.tolist() if np.logical_and.reduce(self.runner_status == runner_status_code.TERMINATED): - print('finished') return True, reset_infos for reset_info in new_reset_infos: if reset_info is None: @@ -263,12 +259,15 @@ def terminate_ops(self, obs_ls, reset_infos, terminated_ls): self.visualize_util.trace_start( trajectory_id=self.now_path_key(reset_info), reference_path=reset_info.data['reference_path'] ) + + end_time = time() + duration = round(end_time - start_time, 2) + log.info(f'[TIME] Env Reset time: {duration}s') return False, reset_infos def eval(self): print('--- VlnMultiEvaluator start ---') obs, reset_info = self.env.reset() - print('obs:', obs) for info in reset_info: if info is None: continue @@ -293,11 +292,14 @@ def eval(self): self.runner_status[[info is None for info in reset_info]] = runner_status_code.TERMINATED while self.env.is_running(): - + # get action from agent obs, action = self.get_action(obs, action) + # step env obs, terminated = self.env_step(action) - env_term, reset_info = self.terminate_ops(obs, reset_info, terminated) - if env_term: + # terminate ops + env_terminate, reset_info = self.terminate_ops(obs, reset_info, terminated) + + if env_terminate: break # save step obs diff --git a/internnav/habitat_extensions/README.md b/internnav/habitat_extensions/README.md new file mode 100644 index 0000000..41dc200 --- /dev/null +++ b/internnav/habitat_extensions/README.md @@ -0,0 +1,133 @@ +# Habitat in InternNav + +This package adapts [Meta AI Habitat](https://aihabitat.org) 
environments and +metrics so they can be used from InternNav's evaluation framework. It provides +an environment wrapper, custom measurements, and evaluator implementations that +bridge Habitat simulations with InternNav agents and distributed evaluation +utilities. + +## Package structure + +``` +habitat_extensions/ +├── __init__.py +├── habitat_env.py +├── habitat_default_evaluator.py +├── habitat_vln_evaluator.py +└── measures.py +``` + +* `__init__.py` re-exports the public entry points for the environment and the + VLN evaluator so they can be imported as + `from internnav.habitat_extensions import HabitatEnv`. +* `habitat_env.py` implements the `Env` subclass that wraps Habitat's + `Env` object. It bootstraps episodes, handles sharding across distributed + ranks, and adapts Habitat's observations to InternNav's expectations. +* `habitat_default_evaluator.py` contains a lightweight evaluator that runs a + conventional Habitat agent inside the InternNav evaluator loop. +* `habitat_vln_evaluator.py` is the task-specific evaluator used for Vision- + and-Language Navigation (VLN). It loads InternNav vision-language models, + orchestrates inference, and logs results during distributed evaluation. +* `measures.py` registers additional Habitat measurements (path length, + oracle metrics, step counts) that are required by the evaluators. + + +## Habitat environment wrapper + +`HabitatEnv` is registered under the key `"habitat"` via the shared +`Env.register` decorator. When InternNav builds an environment from an +`EnvCfg`, the wrapper: + +1. Imports and instantiates the Habitat `Env` using the configuration object + provided in `env_settings['habitat_config']`. +2. Stores the distributed context (`local_rank`, `world_size`) and any output + directory override (`output_path`). +3. Pre-computes the episode list by grouping Habitat episodes by scene, + filtering completed episodes via `progress.json`, and sharding the remaining + work by rank. +4. Implements the standard reset/step/close/render accessors expected by the + InternNav `Env` base class while delegating to the underlying Habitat + simulator. + +This design keeps the Habitat-specific logic isolated from the rest of the +framework and ensures that distributed evaluation proceeds deterministically +across ranks. + +## Evaluation pipeline + +InternNav evaluators extend the shared `DistributedEvaluator` base class, which +handles distributed initialization, environment instantiation, metric +aggregation, and result logging. The Habitat integration provides two +implementations: + +### `HabitatVlnEvaluator` + +The VLN evaluator (`habitat_vln_evaluator.py`) is responsible for coordinating +model inference in Habitat scenes. + +* **Configuration:** During initialization the evaluator reads an `EvalCfg` + whose `env.env_settings['config_path']` points to a Habitat YAML file. The + config is loaded with Habitat's baseline utilities, sensor intrinsics are + cached, and custom measurements (`top_down_map`, `collisions`) are enabled. +* **Environment binding:** The Habitat configuration is injected back into the + `EnvCfg` so the shared `DistributedEvaluator` base class can create the + `HabitatEnv` wrapper with the correct settings. +* **Model loading:** Depending on `cfg.agent.model_settings.mode`, the evaluator + loads either the InternVLA dual-system model or a Qwen2.5-VL model using + Hugging Face Transformers. The processor is configured with left padding and + the model is moved to the rank-local GPU. +* **Episode loop:** + 1. 
`HabitatEnv.reset()` advances to the next episode and returns the first
+     observation.
+  2. The evaluator reads episode metadata (scene, instruction) from Habitat,
+     constructs prompt messages, and collects RGB/depth history for the
+     language model.
+  3. Visual inputs are prepared (resizing, optional look-down depth capture) and
+     depth maps are filtered through `filter_depth` to remove sensor noise.
+  4. The evaluator queries the loaded model for the next action sequence,
+     translates model tokens to Habitat actions via `traj_to_actions`, and
+     steps the environment.
+  5. Per-episode metrics (`success`, `SPL`, oracle success, navigation error)
+     are appended and checkpointed to `progress.json` for resumability.
+* **Aggregation:** After all ranks finish, inherited utilities gather per-rank
+  tensors, compute global averages, and write `result.json` in
+  `output_path`.
+
+### `HabitatDefaultEvaluator` (baseline)
+
+The default evaluator in `habitat_default_evaluator.py` offers a simpler loop
+where a pre-built InternNav agent interacts with the Habitat environment.
+InternNav's agent abstraction is reset with each new Habitat episode, and
+per-step actions are produced via `agent.act()`. The evaluator records the same
+metrics as the VLN evaluator, making it useful for baselines or sanity checks.
+
+## Custom Habitat measurements
+
+`measures.py` registers a suite of metrics with Habitat's registry so that they
+are available in the Habitat configuration:
+
+* `PathLength`: cumulative Euclidean distance traveled by the agent.
+* `OracleNavigationError`: minimum geodesic distance to the goal along the
+  trajectory.
+* `OracleSuccess`: binary success metric derived from oracle navigation error
+  relative to a goal radius (default 3.0 meters).
+* `OracleSPL`: best Success weighted by Path Length value observed during the
+  trajectory.
+* `StepsTaken`: number of actions issued by the agent, including STOP.
+
+These metrics complement Habitat's built-in success and SPL scores, allowing
+InternNav to report a richer set of statistics.
+
+## Extending the integration
+
+* **Adding evaluators:** Subclass `DistributedEvaluator`, supply
+  Habitat-specific initialization similar to `HabitatVlnEvaluator`, and
+  implement `eval_action` and `calc_metrics` (a minimal sketch follows at the
+  end of this document).
+* **Custom sensors or observations:** Augment the Habitat YAML configuration and
+  update `HabitatEnv` or the evaluator to consume the new observation keys.
+* **Additional metrics:** Register new measures in `measures.py` and enable them
+  in the Habitat config via `config.habitat.task.measurements.update(...)`.
+
+By centralizing Habitat-specific logic in this package, InternNav can swap in
+other simulators or extend Habitat support without touching the rest of the
+training and evaluation stack.
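+
+As a concrete starting point, the sketch below shows the shape of such a
+subclass, modeled on `HabitatDefaultEvaluator`. The registry key
+`habitat_my_task`, the class name `MyHabitatEvaluator`, and the stand-in
+`action = 0` policy are illustrative assumptions only; the exact tensor
+conventions expected by `DistributedEvaluator` (e.g. whether per-episode
+metrics must already live on the rank-local GPU) should be checked against the
+base class.
+
+```python
+import torch
+from habitat_baselines.config.default import get_config as get_habitat_config
+
+from internnav.configs.evaluator import EvalCfg
+from internnav.evaluator import DistributedEvaluator, Evaluator
+
+
+@Evaluator.register('habitat_my_task')  # illustrative registry key
+class MyHabitatEvaluator(DistributedEvaluator):
+    def __init__(self, cfg: EvalCfg):
+        # Load the Habitat YAML and hand it to the env wrapper, as the
+        # evaluators in this package do, then let the base class build
+        # HabitatEnv and the distributed context.
+        config = get_habitat_config(cfg.env.env_settings['config_path'])
+        cfg.env.env_settings['habitat_config'] = config
+        super().__init__(cfg)
+
+    def eval_action(self):
+        # Run the episodes sharded to this rank and collect per-episode metrics.
+        sucs, spls = [], []
+        while self.env.is_running:
+            obs = self.env.reset()
+            if not self.env.is_running or obs is None:
+                break
+            done = False
+            while not done:
+                action = 0  # STOP placeholder; replace with a real policy call
+                obs, _, done, _ = self.env.step(action)
+            metrics = self.env.get_metrics()
+            sucs.append(metrics['success'])
+            spls.append(metrics['spl'])
+        self.env.close()
+        # Move these to the rank-local GPU here if the base class expects it.
+        return {'sucs': torch.tensor(sucs), 'spls': torch.tensor(spls)}
+
+    def calc_metrics(self, global_metrics: dict) -> dict:
+        # global_metrics holds the per-episode tensors gathered from all ranks.
+        count = max(len(global_metrics['sucs']), 1)
+        return {
+            'sucs_all': float(global_metrics['sucs'].sum().item()) / count,
+            'spls_all': float(global_metrics['spls'].sum().item()) / count,
+        }
+```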
diff --git a/internnav/habitat_extensions/__init__.py b/internnav/habitat_extensions/__init__.py new file mode 100644 index 0000000..5906e1c --- /dev/null +++ b/internnav/habitat_extensions/__init__.py @@ -0,0 +1,2 @@ +from internnav.habitat_extensions.habitat_env import HabitatEnv +from internnav.habitat_extensions.habitat_vln_evaluator import HabitatVLNEvaluator diff --git a/internnav/habitat_extensions/habitat_default_evaluator.py b/internnav/habitat_extensions/habitat_default_evaluator.py new file mode 100644 index 0000000..c4df86a --- /dev/null +++ b/internnav/habitat_extensions/habitat_default_evaluator.py @@ -0,0 +1,153 @@ +import argparse +import sys + +sys.path.append('./src/diffusion-policy') + + +# Import for Habitat registry side effects — do not remove +import internnav.env.utils.habitat_extensions.measures # noqa: F401 +from internnav.configs.evaluator import EvalCfg +from internnav.evaluator import DistributedEvaluator, Evaluator + +try: + import habitat + from habitat.config.default import get_agent_config + from habitat.config.default_structured_configs import ( + CollisionsMeasurementConfig, + FogOfWarConfig, + TopDownMapMeasurementConfig, + ) + from habitat_baselines.config.default import get_config as get_habitat_config +except Exception as e: + print("Habitat Error:", e) + print("Habitat Evaluation is not loaded.") + + +DEFAULT_IMAGE_TOKEN = "" + + +@Evaluator.register('habitat_evaluator') +class HabitatDefaultEvaluator(DistributedEvaluator): + """ + A default evaluator class for running Habitat-based evaluations in a distributed environment. + + This evaluator is designed to work with the Habitat simulator and performs evaluation of + agents on local episodes. It provides metrics such as success rate (success), SPL (Success weighted by Path Length), + Oracle success rate (oracle_success), and the distance to the goal (distance_to_goal). + + Attributes: + save_video (bool): Whether to save video during the evaluation. + epoch (int): The current epoch of the evaluation process. + max_steps_per_episode (int): The maximum number of steps allowed per episode. + output_path (str): The path where the evaluation results are saved. + config (habitat.config.default.Config): The Habitat configuration used for the environment setup. + agent_config (habitat.config.default.AgentConfig): Configuration specific to the agent in the Habitat simulator. + sim_sensors_config (dict): Configuration for the sensors used by the agent in the simulation. + + Methods: + eval_action() -> dict: + Runs the local episodes and returns a dictionary of evaluation metrics such as success rate, + success weighted by path length (SPL), oracle success, and distance to the goal. + + calc_metrics(global_metrics: dict) -> dict: + Calculates the global evaluation metrics from the distributed results by aggregating local metrics. 
+ """ + + def __init__(self, cfg: EvalCfg): + args = argparse.Namespace(**cfg.eval_settings) + self.args = args + self.save_video = args.save_video + self.epoch = args.epoch + self.max_steps_per_episode = args.max_steps_per_episode + self.output_path = args.output_path + + # create habitat config + self.config_path = cfg.env.env_settings['config_path'] + self.config = get_habitat_config(self.config_path) + self.agent_config = get_agent_config(self.config.habitat.simulator) + self.sim_sensors_config = self.config.habitat.simulator.agents.main_agent.sim_sensors + + with habitat.config.read_write(self.config): + self.config.habitat.task.measurements.update( + { + "top_down_map": TopDownMapMeasurementConfig( + map_padding=3, + map_resolution=1024, + draw_source=True, + draw_border=True, + draw_shortest_path=True, + draw_view_points=True, + draw_goal_positions=True, + draw_goal_aabbs=True, + fog_of_war=FogOfWarConfig( + draw=True, + visibility_dist=5.0, + fov=90, + ), + ), + "collisions": CollisionsMeasurementConfig(), + } + ) + cfg.env.env_settings['habitat_config'] = self.config + cfg.env.env_settings['output_path'] = self.output_path + + # init agent and env + super().__init__(cfg) + + def eval_action(self): + """ + Run local episodes on this rank. + + Returns dict[str, Tensor] on GPU (1D tensors of same length). + """ + sucs, spls, oss, nes = [], [], [], [] + env = self.env + + while env.is_running: + obs = env.reset() + if not env.is_running or obs is None: + break + + episode = env.env.current_episode + self.agent.reset(episode, env) + + done = False + step_id = 0 + while not done and step_id <= self.max_steps_per_episode: + action = self.agent.act(obs, env, info=None) + obs, reward, done, info = env.step(action) + step_id += 1 + + m = env.get_metrics() + sucs.append(m["success"]) + spls.append(m["spl"]) + oss.append(m["oracle_success"]) + nes.append(m["distance_to_goal"]) + + env.close() + return { + "sucs": sucs, # shape [N_local] + "spls": spls, # shape [N_local] + "oss": oss, # shape [N_local] + "nes": nes, # shape [N_local] + } + + def calc_metrics(self, global_metrics: dict) -> dict: + """ + global_metrics["sucs"] etc. are global 1-D CPU tensors with all episodes. + """ + sucs_all = global_metrics["sucs"] + spls_all = global_metrics["spls"] + oss_all = global_metrics["oss"] + nes_all = global_metrics["nes"] + + # avoid /0 if no episodes + denom = max(len(sucs_all), 1) + + return { + "sucs_all": float(sucs_all.mean().item()) if denom > 0 else 0.0, + "spls_all": float(spls_all.mean().item()) if denom > 0 else 0.0, + "oss_all": float(oss_all.mean().item()) if denom > 0 else 0.0, + "nes_all": float(nes_all.mean().item()) if denom > 0 else 0.0, + # "length" will be filled by base class + } diff --git a/internnav/habitat_extensions/habitat_env.py b/internnav/habitat_extensions/habitat_env.py new file mode 100644 index 0000000..1b0f3f4 --- /dev/null +++ b/internnav/habitat_extensions/habitat_env.py @@ -0,0 +1,128 @@ +import json +import os +from typing import Any, Dict, List, Optional + +from internnav.configs.evaluator import EnvCfg, TaskCfg +from internnav.env import base + + +@base.Env.register('habitat') +class HabitatEnv(base.Env): + def __init__(self, env_config: EnvCfg, task_config: TaskCfg): + """ + env_settings include: + - habitat_config: loaded from get_habitat_config + - rank: int, rank index for sharding + - world_size: int, total number of ranks + """ + try: + from habitat import Env + except ImportError as e: + raise RuntimeError( + "Habitat modules could not be imported. 
" "Make sure both repositories are installed and on PYTHONPATH." + ) from e + + super().__init__(env_config, task_config) + + self.config = env_config.env_settings['habitat_config'] + self._env = Env(self.config) + + self.rank = env_config.env_settings.get('rank', 0) + self.world_size = env_config.env_settings.get('world_size', 1) + self._current_episode_index: int = 0 + self._last_obs: Optional[Dict[str, Any]] = None + + self.is_running = True + self.output_path = env_config.env_settings.get('output_path', './output') + + # generate episodes + # self._env.episodes = self._env.episodes[0:1] # for debug + self.episodes = self.generate_episodes() + # print(self.episodes) + + def generate_episodes(self) -> List[Any]: + """ + Generate list of episodes for the current split, already: + - grouped by scene + - filtered by done_res (the path is self.output_path/progress.json) + - sharded by (rank, world_size) + """ + all_episodes = [] + + # group episodes by scene + scene_episode_dict: Dict[str, List[Any]] = {} + for episode in self._env.episodes: + scene_episode_dict.setdefault(episode.scene_id, []).append(episode) + + # load done_res + done_res = set() + result_path = os.path.join(self.output_path, 'progress.json') + if os.path.exists(result_path): + with open(result_path, 'r') as f: + for line in f: + res = json.loads(line) + # only skip if current format has scene_id + if "scene_id" in res: + done_res.add((res["scene_id"], res["episode_id"])) + + # iterate scenes in order, collect all episodes + for scene in sorted(scene_episode_dict.keys()): + per_scene_eps = scene_episode_dict[scene] + scene_id = scene.split('/')[-2] + + # shard by rank index / world_size + for episode in per_scene_eps[self.rank :: self.world_size]: + episode_id = int(episode.episode_id) + if (scene_id, episode_id) in done_res: + continue + all_episodes.append(episode) + + return all_episodes + + def reset(self): + """ + load next episode and return first observation + """ + # no more episodes + if not (0 <= self._current_episode_index < len(self.episodes)): + self.is_running = False + return + + # Manually set to next episode in habitat + self._env.current_episode = self.episodes[self._current_episode_index] + self._current_episode_index += 1 + + # Habitat reset + self._last_obs = self._env.reset() + + return self._last_obs + + def step(self, action: List[Any]): + """ + step the environment with given action + + Args: action: List[Any], action for each env in the batch + + Return: obs, reward, done, info + """ + obs = self._env.step(action) + done = self._env.episode_over + info = self._env.get_metrics() + reward = info.get('reward', 0.0) + return obs, reward, done, info + + def close(self): + print('Habitat Env close') + self._env.close() + + def render(self): + self._env.render() + + def get_observation(self) -> Dict[str, Any]: + return self._env.get_observations() + + def get_metrics(self) -> Dict[str, Any]: + return self._env.get_metrics() + + def get_current_episode(self): + return self._env.current_episode diff --git a/internnav/habitat_extensions/habitat_vln_evaluator.py b/internnav/habitat_extensions/habitat_vln_evaluator.py new file mode 100644 index 0000000..2d552fd --- /dev/null +++ b/internnav/habitat_extensions/habitat_vln_evaluator.py @@ -0,0 +1,840 @@ +import argparse +import json +import os +import sys + +sys.path.append('./src/diffusion-policy') +import copy +import itertools +import random +import re +from collections import OrderedDict + +import habitat +import numpy as np +import quaternion +import torch 
+import tqdm +from depth_camera_filtering import filter_depth +from habitat.config.default import get_agent_config +from habitat.config.default_structured_configs import ( + CollisionsMeasurementConfig, + FogOfWarConfig, + TopDownMapMeasurementConfig, +) +from habitat.tasks.nav.shortest_path_follower import ShortestPathFollower +from habitat.utils.visualizations.utils import images_to_video, observations_to_image +from habitat_baselines.config.default import get_config as get_habitat_config +from PIL import Image, ImageDraw, ImageFont +from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration +from transformers.image_utils import to_numpy_array + +from internnav.configs.evaluator import EvalCfg +from internnav.evaluator import DistributedEvaluator, Evaluator +from internnav.model.basemodel.internvla_n1.internvla_n1 import InternVLAN1ForCausalLM +from internnav.model.utils.vln_utils import ( + chunk_token, + open_image, + split_and_clean, + traj_to_actions, +) + +# Import for Habitat registry side effects — do not remove +import internnav.habitat_extensions.measures # noqa: F401 # isort: skip + +DEFAULT_IMAGE_TOKEN = "" + + +@Evaluator.register('habitat_vln') +class HabitatVLNEvaluator(DistributedEvaluator): + def __init__(self, cfg: EvalCfg): + args = argparse.Namespace(**cfg.eval_settings) + self.save_video = args.save_video + self.epoch = args.epoch + self.max_steps_per_episode = args.max_steps_per_episode + self.output_path = args.output_path + + # create habitat config + self.config_path = cfg.env.env_settings['config_path'] + self.config = get_habitat_config(self.config_path) + self.agent_config = get_agent_config(self.config.habitat.simulator) + self.sim_sensors_config = self.config.habitat.simulator.agents.main_agent.sim_sensors + + with habitat.config.read_write(self.config): + self.config.habitat.task.measurements.update( + { + "top_down_map": TopDownMapMeasurementConfig( + map_padding=3, + map_resolution=1024, + draw_source=True, + draw_border=True, + draw_shortest_path=True, + draw_view_points=True, + draw_goal_positions=True, + draw_goal_aabbs=True, + fog_of_war=FogOfWarConfig( + draw=True, + visibility_dist=5.0, + fov=90, + ), + ), + "collisions": CollisionsMeasurementConfig(), + } + ) + cfg.env.env_settings['habitat_config'] = self.config + cfg.env.env_settings['output_path'] = self.output_path + + # init agent and env + super().__init__(cfg, init_agent=False) + + # ------------------------------------- model ------------------------------------------ + self.model_args = argparse.Namespace(**cfg.agent.model_settings) + + processor = AutoProcessor.from_pretrained(self.model_args.model_path) + processor.tokenizer.padding_side = 'left' + + device = torch.device(f"cuda:{self.local_rank}") + if self.model_args.mode == 'dual_system': + model = InternVLAN1ForCausalLM.from_pretrained( + self.model_args.model_path, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + device_map={"": device}, + ) + elif self.model_args.mode == 'system2': + model = Qwen2_5_VLForConditionalGeneration.from_pretrained( + self.model_args.model_path, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + device_map={"": device}, + ) + else: + raise ValueError(f"Invalid mode: {self.model_args.mode}") + + model.eval() + self.device = device + + self.model = model + self.processor = processor + + # refactor: this part used in three places + prompt = "You are an autonomous navigation assistant. Your task is to . Where should you go next to stay on track? 
Please output the next waypoint\'s coordinates in the image. Please output STOP when you have successfully completed the task." + answer = "" + self.conversation = [{"from": "human", "value": prompt}, {"from": "gpt", "value": answer}] + + self.conjunctions = [ + 'you can see ', + 'in front of you is ', + 'there is ', + 'you can spot ', + 'you are toward the ', + 'ahead of you is ', + 'in your sight is ', + ] + + self.actions2idx = OrderedDict( + { + 'STOP': [0], + "↑": [1], + "←": [2], + "→": [3], + "↓": [5], + } + ) + + self.objectnav_instructions = ["Search for the {target_object}."] + + self.num_frames = self.model_args.num_frames + self.num_future_steps = self.model_args.num_future_steps + self.num_history = self.model_args.num_history + + self._camera_height = self.sim_sensors_config.rgb_sensor.position[1] + self._min_depth = self.sim_sensors_config.depth_sensor.min_depth + self._max_depth = self.sim_sensors_config.depth_sensor.max_depth + + camera_fov_rad = np.deg2rad(self.sim_sensors_config.depth_sensor.hfov) + self._camera_fov = camera_fov_rad + self._fx = self._fy = self.sim_sensors_config.depth_sensor.width / (2 * np.tan(camera_fov_rad / 2)) + + def eval_action(self): + """ + Run local episodes on this rank. + + Returns dict[str, Tensor] on GPU (1D tensors of same length). + """ + # Old behavior was something like: + # sucs, spls, oss, nes, ep_num = self.eval_action(self.rank) + # Now just implement the actual eval here and return dict. + + sucs, spls, oss, nes, _ = self._run_local_eval() + + return { + "sucs": sucs, # shape [N_local] + "spls": spls, # shape [N_local] + "oss": oss, # shape [N_local] + "nes": nes, # shape [N_local] + } + + def calc_metrics(self, global_metrics: dict) -> dict: + """ + global_metrics["sucs"] etc. are global 1-D CPU tensors with all episodes. + """ + sucs_all = global_metrics["sucs"] + spls_all = global_metrics["spls"] + oss_all = global_metrics["oss"] + nes_all = global_metrics["nes"] + + # avoid /0 if no episodes + denom = max(len(sucs_all), 1) + + return { + "sucs_all": float(sucs_all.mean().item()) if denom > 0 else 0.0, + "spls_all": float(spls_all.mean().item()) if denom > 0 else 0.0, + "oss_all": float(oss_all.mean().item()) if denom > 0 else 0.0, + "nes_all": float(nes_all.mean().item()) if denom > 0 else 0.0, + # "length" will be filled by base class + } + + def _run_local_eval(self) -> None: # noqa: C901 + """ + Run local evaluation on this rank. + + Important: if resuming from previous results, need to read from / write to "self.output_path/progress.json". + For each episode, save the result dict in jsonl format to that file. + In Env, the episodes are already filtered by this file, tasks that have the same (scene_id, episode_id) are skipped. 
+ + + Returns + ------- + dict[str, Tensor]: + { + "sucs": [N_local], + "spls": [N_local], + "oss": [N_local], + "nes": [N_local], + } + """ + # Create / get env + # self.env = self.env # HabitatEnv from DistributedEvaluator + + sucs, spls, oss, nes = [], [], [], [] + self.model.eval() + + # resume from previous results + if os.path.exists(os.path.join(self.output_path, 'progress.json')): + with open(os.path.join(self.output_path, 'progress.json'), 'r') as f: + for line in f.readlines(): + res = json.loads(line) + if "scene_id" not in res: + print("This evaluation has already finished!") + return ( + torch.tensor(sucs).to(self.device), + torch.tensor(spls).to(self.device), + torch.tensor(oss).to(self.device), + torch.tensor(nes).to(self.device), + torch.tensor(len(sucs)).to(self.device), + ) + if self.rank == 0: # noqa: F405 + sucs.append(res['success']) + spls.append(res['spl']) + oss.append(res['os']) + nes.append(res['ne']) + + # Episode loop is now driven by env.reset() + env.is_running + process_bar = tqdm.tqdm(total=len(self.env.episodes), desc=f"Eval Epoch {self.epoch} Rank {self.rank}") + while self.env.is_running: + + # ------------ 1. Start of episode ------------ + observations = self.env.reset() + if not self.env.is_running or observations is None: + break + + # ---- episode meta (scene_id, episode_id, instruction) ---- + # we get it from the underlying habitat env + episode = self.env.get_current_episode() + scene_id = episode.scene_id.split('/')[-2] + episode_id = int(episode.episode_id) + episode_instruction = ( + episode.instruction.instruction_text if 'objectnav' not in self.config_path else episode.object_category + ) + print("episode start", episode_instruction) + + agent_state = self.env._env.sim.get_agent_state() + rotation = agent_state.rotation + translation = agent_state.position + rotation_matrix = quaternion.as_rotation_matrix(rotation) + transformation_matrix = np.eye(4) + transformation_matrix[:3, :3] = rotation_matrix + transformation_matrix[:3, 3] = translation + + agent = ShortestPathFollower(self.env._env.sim, 0.25, False) + + # save first frame per rank to validate sim quality + os.makedirs(os.path.join(self.output_path, f'check_sim_{self.epoch}'), exist_ok=True) + Image.fromarray(observations['rgb']).save( + os.path.join(self.output_path, f'check_sim_{self.epoch}', f'rgb_{self.rank}.jpg') + ) + + vis_frames = [] + step_id = 0 + + if self.save_video: + os.makedirs(os.path.join(self.output_path, f'vis_{self.epoch}', f'{scene_id}'), exist_ok=True) + initial_height = self.env._env.sim.get_agent_state().position[1] + + rgb_list = [] + action_seq = [] + output_ids = None + + goal = None + action = None + messages = [] + local_actions = [] + + done = False + + # ---------- 2. 
Episode step loop ----------- + while (not done) and (step_id <= self.max_steps_per_episode): + # refactor agent get action + rgb = observations["rgb"] + depth = observations["depth"] + x, y = observations["gps"] + camera_yaw = observations["compass"][0] + depth = filter_depth(depth.reshape(depth.shape[:2]), blur_type=None) + depth = depth * (self._max_depth - self._min_depth) + self._min_depth + depth = depth * 1000 + + agent_state = self.env._env.sim.get_agent_state() + height = agent_state.position[1] - initial_height + camera_position = np.array([x, -y, self._camera_height + height]) + tf_camera_to_episodic = ( + self.xyz_yaw_pitch_to_tf_matrix(camera_position, camera_yaw, np.deg2rad(30)) + @ self.get_axis_align_matrix() + ) + + image = Image.fromarray(rgb).convert('RGB') + save_raw_image = image.copy() + + save_dot = False + if action == 5: + look_down_image = image + save_raw_image = look_down_image.copy() + look_down_depth, resize_shape = self.preprocess_depth_image_v2( + Image.fromarray(depth.astype(np.uint16), mode='I;16'), + do_depth_scale=True, + depth_scale=1000, + target_height=224, + target_width=224, + ) + look_down_depth = torch.as_tensor(np.ascontiguousarray(look_down_depth)).float() + look_down_depth[look_down_depth > 5.0] = 5.0 + else: + image = image.resize((self.model_args.resize_w, self.model_args.resize_h)) + rgb_list.append(image) + + if self.model_args.mode == 'dual_system': + down_observations, _, done, _ = self.env.step(5) + down_observations, _, done, _ = self.env.step(5) + + look_down_image = Image.fromarray(down_observations["rgb"]).convert('RGB') + depth = down_observations["depth"] + depth = filter_depth(depth.reshape(depth.shape[:2]), blur_type=None) + depth = depth * (self._max_depth - self._min_depth) + self._min_depth + depth = depth * 1000 + look_down_depth, resize_shape = self.preprocess_depth_image_v2( + Image.fromarray(depth.astype(np.uint16), mode='I;16'), + do_depth_scale=True, + depth_scale=1000, + target_height=224, + target_width=224, + ) + look_down_depth = torch.as_tensor(np.ascontiguousarray(look_down_depth)).float() + look_down_depth[look_down_depth > 5.0] = 5.0 + + self.env.step(4) + self.env.step(4) + + info = self.env.get_metrics() + + if len(action_seq) == 0 and goal is None: + if action != 5: + sources = copy.deepcopy(self.conversation) + sources[0]["value"] = sources[0]["value"].replace( + '.', episode.instruction.instruction_text[:-1] + ) + cur_images = rgb_list[-1:] + if step_id == 0: + history_id = [] + else: + history_id = np.unique( + np.linspace(0, step_id - 1, self.num_history, dtype=np.int32) + ).tolist() + placeholder = (DEFAULT_IMAGE_TOKEN + '\n') * len(history_id) + sources[0]["value"] += f' These are your historical observations: {placeholder}.' + + history_id = sorted(history_id) + print('history_idddddddd', step_id, history_id) + input_images = [rgb_list[i] for i in history_id] + cur_images + input_img_id = 0 + else: + assert action == 5 + sources = [{"from": "human", "value": ""}, {"from": "gpt", "value": ""}] + input_images += [look_down_image] + # messages.append( + # {'role': 'assistant', 'content': [{'type': 'text', 'text': llm_outputs}]} # noqa: F405 + # ) + input_img_id = -1 + + prompt = random.choice(self.conjunctions) + DEFAULT_IMAGE_TOKEN + sources[0]["value"] += f" {prompt}." 
+ print('sources', step_id, sources) + prompt_instruction = copy.deepcopy(sources[0]["value"]) + parts = split_and_clean(prompt_instruction) + + content = [] + for i in range(len(parts)): + if parts[i] == "": + content.append({"type": "image", "image": input_images[input_img_id]}) + input_img_id += 1 + else: + content.append({"type": "text", "text": parts[i]}) + + messages.append({'role': 'user', 'content': content}) + + print('step_id', step_id, 'messages:', messages) + + text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + + inputs = self.processor(text=[text], images=input_images, return_tensors="pt").to(self.model.device) + + with torch.no_grad(): + output_ids = self.model.generate(**inputs, max_new_tokens=128, do_sample=False) + + llm_outputs = self.processor.tokenizer.decode( + output_ids[0][inputs.input_ids.shape[1] :], skip_special_tokens=True + ) + print('step_id:', step_id, 'output text:', llm_outputs) + + if bool(re.search(r'\d', llm_outputs)): + forward_action = 0 + coord = [int(c) for c in re.findall(r'\d+', llm_outputs)] + pixel_goal = [int(coord[1]), int(coord[0])] + + intrinsic_matrix = self.get_intrinsic_matrix( + self.config.habitat.simulator.agents.main_agent.sim_sensors.rgb_sensor + ) + goal = self.pixel_to_gps(pixel_goal, depth / 1000, intrinsic_matrix, tf_camera_to_episodic) + print('before', goal, depth.shape) + goal = (transformation_matrix @ np.array([-goal[1], 0, -goal[0], 1]))[:3] + + if not self.env._env.sim.pathfinder.is_navigable(np.array(goal)): + goal = np.array(self.env._env.sim.pathfinder.snap_point(np.array(goal))) + + # look down --> horizontal + self.env.step(4) + self.env.step(4) + + # Forking logic based on mode + if self.model_args.mode == 'system2': + action = agent.get_next_action(goal) + if action == 0: + goal = None + output_ids = None + action = 2 # random action + print('conduct a random action 2') + observations, _, done, _ = self.env.step(action) + step_id += 1 + messages = [] + continue + else: # dual-system logic + local_actions = [] + pixel_values = inputs.pixel_values + image_grid_thw = torch.cat([thw.unsqueeze(0) for thw in inputs.image_grid_thw], dim=0) + + with torch.no_grad(): + traj_latents = self.model.generate_latents(output_ids, pixel_values, image_grid_thw) + + # prepocess align with navdp + image_dp = ( + torch.tensor(np.array(look_down_image.resize((224, 224)))).to(torch.bfloat16) / 255 + ) + pix_goal_image = copy.copy(image_dp) + images_dp = torch.stack([pix_goal_image, image_dp]).unsqueeze(0).to(self.device) + depth_dp = look_down_depth.unsqueeze(-1).to(torch.bfloat16) + pix_goal_depth = copy.copy(depth_dp) + depths_dp = torch.stack([pix_goal_depth, depth_dp]).unsqueeze(0).to(self.device) + + with torch.no_grad(): + dp_actions = self.model.generate_traj( + traj_latents, images_dp, depths_dp, use_async=True + ) + + random_choice = np.random.choice(dp_actions.shape[0]) + if self.model_args.continuous_traj: + action_list = traj_to_actions(dp_actions) + if len(action_list) < 8: + action_list += [0] * (8 - len(action_list)) + else: + action_list = chunk_token(dp_actions[random_choice]) + + local_actions = action_list + if len(local_actions) >= 4: + local_actions = local_actions[:4] + action = local_actions[0] + if action == 0: + goal = None + output_ids = None + action = 2 # random action + print('conduct a random action 2') + observations, _, done, _ = self.env.step(action) + step_id += 1 + messages = [] + continue + + print('predicted goal', pixel_goal, goal, flush=True) + else: + 
action_seq = self.parse_actions(llm_outputs) + print('actions', action_seq, flush=True) + + if len(action_seq) != 0: + action = action_seq[0] + action_seq.pop(0) + elif goal is not None: + # Forking logic based on mode + if self.model_args.mode == 'system2': + action = agent.get_next_action(goal) + action = action.detach().cpu().numpy()[0] if isinstance(action, torch.Tensor) else action + action = action[0] if hasattr(action, "__len__") else action + else: # dual-system logic + if len(local_actions) == 0: + # navdp + local_actions = [] + image_dp = ( + torch.tensor(np.array(look_down_image.resize((224, 224)))).to(torch.bfloat16) / 255 + ) + + images_dp = torch.stack([pix_goal_image, image_dp]).unsqueeze(0).to(self.device) + depth_dp = look_down_depth.unsqueeze(-1).to(torch.bfloat16) + + depths_dp = torch.stack([pix_goal_depth, depth_dp]).unsqueeze(0).to(self.device) + with torch.no_grad(): + dp_actions = self.model.generate_traj( + traj_latents, images_dp, depths_dp, use_async=True + ) + + random_choice = np.random.choice(dp_actions.shape[0]) + if self.model_args.continuous_traj: + action_list = traj_to_actions(dp_actions) + if len(action_list) < 8: + action_list += [0] * (8 - len(action_list)) + else: + action_list = chunk_token(dp_actions[random_choice]) + print("first action_list", action_list) + + local_actions = action_list + if len(local_actions) >= 4: + local_actions = local_actions[:4] + # if len(local_actions) >= 2: + # local_actions = local_actions[:2] + + print("local_actions", local_actions) + + action = local_actions.pop(0) + # navdp + else: + action = local_actions.pop(0) + + forward_action += 1 + print('forward_action', forward_action, flush=True) + if forward_action > 8: + goal = None + output_ids = None + messages = [] + step_id += 1 + forward_action = 0 + local_actions = [] + continue + if action == 0: + goal = None + output_ids = None + messages = [] + step_id += 1 + forward_action = 0 + local_actions = [] + continue + else: + action = 0 + + if info['top_down_map'] is not None: + if save_dot: + save_raw_image = self.dot_matrix_two_dimensional( + save_raw_image, save_img=False, save_path=f'test_{step_id}.jpg', pixel_goal=pixel_goal + ) + if self.save_video: + frame = observations_to_image({'rgb': np.asarray(save_raw_image)}, info) + vis_frames.append(frame) + + print("step_id", step_id, "action", action) + + # refactor: core + if action == 5: + self.env.step(action) + observations, _, done, _ = self.env.step(action) + else: + observations, _, done, _ = self.env.step(action) + step_id += 1 + messages = [] + + # ---------- 3. 
End of episode ----------- + # Update result and write progress to the output_path/progress.json + + process_bar.update(1) + + # After the episode finishes, collect metrics: + metrics = self.env.get_metrics() + + sucs.append(metrics['success']) + spls.append(metrics['spl']) + oss.append(metrics['oracle_success']) + nes.append(metrics["distance_to_goal"]) + + print( + f"scene_episode {scene_id}_{episode_id:04d} success: {metrics['success']}, " + f"spl: {metrics['spl']}, os: {metrics['oracle_success']}, " + f"ne: {metrics['distance_to_goal']}" + ) + + # Write per-episode result.json entry (still per-rank) + result = { + "scene_id": scene_id, + "episode_id": episode_id, + "success": metrics["success"], + "spl": metrics["spl"], + "os": metrics['oracle_success'], + "ne": metrics["distance_to_goal"], + "steps": step_id, + "episode_instruction": episode_instruction, + } + os.makedirs(self.output_path, exist_ok=True) + with open(os.path.join(self.output_path, 'progress.json'), 'a') as f: + f.write(json.dumps(result) + "\n") + if self.save_video: + images_to_video( + vis_frames, + os.path.join(self.output_path, f'vis_{self.epoch}', f'{scene_id}'), + f'{episode_id:04d}', + fps=6, + quality=9, + ) + vis_frames.clear() + + self.env.close() + + return ( + torch.tensor(sucs).to(self.device), + torch.tensor(spls).to(self.device), + torch.tensor(oss).to(self.device), + torch.tensor(nes).to(self.device), + torch.tensor(len(sucs)).to(self.device), + ) + + def parse_actions(self, output): + action_patterns = '|'.join(re.escape(action) for action in self.actions2idx) + # import ipdb; ipdb.set_trace() + regex = re.compile(action_patterns) + matches = regex.findall(output) + actions = [self.actions2idx[match] for match in matches] + actions = itertools.chain.from_iterable(actions) + return list(actions) + + def preprocess_depth_image_v2( + self, depth_image, do_depth_scale=True, depth_scale=1000, target_height=None, target_width=None + ): + if target_height is None: + target_height = self.image_processor.crop_size['height'] # 384 + target_width = self.image_processor.crop_size['width'] # 384 + + resized_depth_image = depth_image.resize((target_width, target_height), Image.NEAREST) + + img = to_numpy_array(resized_depth_image) + if do_depth_scale: + img = img / depth_scale + + return img, (target_width, target_height) + + def get_intrinsic_matrix(self, sensor_cfg) -> np.ndarray: + width = sensor_cfg.width + height = sensor_cfg.height + fov = sensor_cfg.hfov + fx = (width / 2.0) / np.tan(np.deg2rad(fov / 2.0)) + fy = fx # Assuming square pixels (fx = fy) + cx = (width - 1.0) / 2.0 + cy = (height - 1.0) / 2.0 + + intrinsic_matrix = np.array( + [[fx, 0.0, cx, 0.0], [0.0, fy, cy, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]] + ) + return intrinsic_matrix + + def get_axis_align_matrix(self): + ma = np.array([[0, 0, 1, 0], [-1, 0, 0, 0], [0, -1, 0, 0], [0, 0, 0, 1]]) + return ma + + def xyz_yaw_to_tf_matrix(self, xyz: np.ndarray, yaw: float) -> np.ndarray: + x, y, z = xyz + transformation_matrix = np.array( + [ + [np.cos(yaw), -np.sin(yaw), 0, x], + [np.sin(yaw), np.cos(yaw), 0, y], + [0, 0, 1, z], + [0, 0, 0, 1], + ] + ) + return transformation_matrix + + def xyz_pitch_to_tf_matrix(self, xyz: np.ndarray, pitch: float) -> np.ndarray: + """Converts a given position and pitch angle to a 4x4 transformation matrix. + + Args: + xyz (np.ndarray): A 3D vector representing the position. + pitch (float): The pitch angle in radians for y axis. + Returns: + np.ndarray: A 4x4 transformation matrix. 
+ """ + + x, y, z = xyz + transformation_matrix = np.array( + [ + [np.cos(pitch), 0, np.sin(pitch), x], + [0, 1, 0, y], + [-np.sin(pitch), 0, np.cos(pitch), z], + [0, 0, 0, 1], + ] + ) + return transformation_matrix + + def xyz_yaw_pitch_to_tf_matrix(self, xyz: np.ndarray, yaw: float, pitch: float) -> np.ndarray: + """Converts a given position and yaw, pitch angles to a 4x4 transformation matrix. + + Args: + xyz (np.ndarray): A 3D vector representing the position. + yaw (float): The yaw angle in radians. + pitch (float): The pitch angle in radians for y axis. + Returns: + np.ndarray: A 4x4 transformation matrix. + """ + x, y, z = xyz + rot1 = self.xyz_yaw_to_tf_matrix(xyz, yaw)[:3, :3] + rot2 = self.xyz_pitch_to_tf_matrix(xyz, pitch)[:3, :3] + transformation_matrix = np.eye(4) + transformation_matrix[:3, :3] = rot1 @ rot2 + transformation_matrix[:3, 3] = xyz + return transformation_matrix + + def pixel_to_gps(self, pixel, depth, intrinsic, tf_camera_to_episodic): + ''' + Args: + pixel: (2,) - [u, v] pixel coordinates + depth: (H, W) - depth image where depth[v, u] gives depth in meters + intrinsic: (4, 4) - camera intrinsic matrix + tf_camera_to_episodic: (4, 4) - transformation from camera to episodic frame + Returns: + (x, y): (x, y) coordinates in the episodic frame + ''' + v, u = pixel + z = depth[v, u] + print("depthhhhhhhhhhhhhh", z) + + x = (u - intrinsic[0, 2]) * z / intrinsic[0, 0] + y = (v - intrinsic[1, 2]) * z / intrinsic[1, 1] + point_camera = np.array([x, y, z, 1.0]) + + # Transform to episodic frame + point_episodic = tf_camera_to_episodic @ point_camera + point_episodic = point_episodic[:3] / point_episodic[3] + + x = point_episodic[0] + y = point_episodic[1] + + return (x, y) # same as habitat gps + + def dot_matrix_two_dimensional( + self, + image_or_image_path, + save_path=None, + dots_size_w=8, + dots_size_h=8, + save_img=False, + font_path='fonts/arial.ttf', + pixel_goal=None, + ): + """ + takes an original image as input, save the processed image to save_path. Each dot is labeled with two-dimensional Cartesian coordinates (x,y). Suitable for single-image tasks. + control args: + 1. dots_size_w: the number of columns of the dots matrix + 2. 
dots_size_h: the number of rows of the dots matrix + """ + with open_image(image_or_image_path) as img: + if img.mode != 'RGB': + img = img.convert('RGB') + draw = ImageDraw.Draw(img, 'RGB') + + width, height = img.size + grid_size_w = dots_size_w + 1 + grid_size_h = dots_size_h + 1 + cell_width = width / grid_size_w + cell_height = height / grid_size_h + + font = ImageFont.truetype(font_path, width // 40) # Adjust font size if needed; default == width // 40 + + target_i = target_j = None + if pixel_goal is not None: + y_pixel, x_pixel = pixel_goal[0], pixel_goal[1] + # Validate pixel coordinates + if not (0 <= x_pixel < width and 0 <= y_pixel < height): + raise ValueError(f"pixel_goal {pixel_goal} exceeds image dimensions ({width}x{height})") + + # Convert to grid coordinates + target_i = round(x_pixel / cell_width) + target_j = round(y_pixel / cell_height) + + # Validate grid bounds + if not (1 <= target_i <= dots_size_w and 1 <= target_j <= dots_size_h): + raise ValueError( + f"pixel_goal {pixel_goal} maps to grid ({target_j},{target_i}), " + f"valid range is (1,1)-({dots_size_h},{dots_size_w})" + ) + + count = 0 + + for j in range(1, grid_size_h): + for i in range(1, grid_size_w): + x = int(i * cell_width) + y = int(j * cell_height) + + pixel_color = img.getpixel((x, y)) + # choose a more contrasting color from black and white + if pixel_color[0] + pixel_color[1] + pixel_color[2] >= 255 * 3 / 2: + opposite_color = (0, 0, 0) + else: + opposite_color = (255, 255, 255) + + if pixel_goal is not None and i == target_i and j == target_j: + opposite_color = (255, 0, 0) # Red for target + + circle_radius = width // 240 # Adjust dot size if needed; default == width // 240 + draw.ellipse( + [(x - circle_radius, y - circle_radius), (x + circle_radius, y + circle_radius)], + fill=opposite_color, + ) + + text_x, text_y = x + 3, y + count_w = count // dots_size_w + count_h = count % dots_size_w + label_str = f"({count_w+1},{count_h+1})" + draw.text((text_x, text_y), label_str, fill=opposite_color, font=font) + count += 1 + if save_img: + print(">>> dots overlaid image processed, stored in", save_path) + img.save(save_path) + return img diff --git a/internnav/env/utils/habitat_extensions/measures.py b/internnav/habitat_extensions/measures.py similarity index 100% rename from internnav/env/utils/habitat_extensions/measures.py rename to internnav/habitat_extensions/measures.py diff --git a/internnav/model/utils/utils.py b/internnav/model/utils/utils.py index 02330ea..692f13c 100755 --- a/internnav/model/utils/utils.py +++ b/internnav/model/utils/utils.py @@ -349,6 +349,9 @@ def get_action(diffusion_output, action_stats, cumsum=True): ndeltas = unnormalize_data(ndeltas, action_stats) if cumsum: + import torch + + torch.use_deterministic_algorithms(False) actions = torch.cumsum(ndeltas, dim=1) # This get the relative actions (not delta) from the diffusion output else: actions = ndeltas @@ -391,11 +394,9 @@ def unnormalize_data(ndata, stats): ndata_part = (ndata[:, :2] + 1) / 2 try: data = ndata_part * (stats['max'].to(device) - stats['min'].to(device)) + stats['min'].to(device) - except Exception as e: + except Exception: data = ndata_part * (stats.max.to(device) - stats.min.to(device)) + stats.min.to(device) - # if len(ndata.shape) == 3: - # data = torch.cat([data, ndata[:, 2:]], dim=1) return data @@ -440,7 +441,7 @@ def load_dataset(dataset_root_dir, split, logger=None, dataset_type='r2r'): item['c_reference_path'].append([path[0], -path[2], path[1]]) item['reference_path'] = item['c_reference_path'] 
del item['c_reference_path'] - + if dataset_type == 'kujiale': load_data[f'{str(item["trajectory_id"])}_{str(item["episode_id"])}'].append(item) else: @@ -448,4 +449,3 @@ def load_dataset(dataset_root_dir, split, logger=None, dataset_type='r2r'): if logger is not None: logger.info(f'Loaded data with a total of {len(load_data)} items from {split}') return load_data - diff --git a/internnav/projects/__init__.py b/internnav/projects/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/internnav/projects/dataloader/data_reviser.py b/internnav/projects/dataloader/data_reviser.py deleted file mode 100644 index d4841b4..0000000 --- a/internnav/projects/dataloader/data_reviser.py +++ /dev/null @@ -1,449 +0,0 @@ -fall_path_z_0_3 = [ - 70, - 121, - 146, - 156, - 172, - 326, - 349, - 372, - 394, - 415, - 434, - 469, - 531, - 550, - 580, - 626, - 674, - 700, - 768, - 808, - 823, - 835, - 854, - 859, - 958, - 1009, - 1058, - 1065, - 1093, - 1105, - 1142, - 1205, - 1238, - 1245, - 1263, - 1290, - 1295, - 1353, - 1400, - 1403, - 1455, - 1470, - 1530, - 1644, - 1645, - 1650, - 1734, - 1771, - 1848, - 1876, - 1880, - 1893, - 1925, - 1928, - 1957, - 1967, - 1995, - 2051, - 2061, - 2100, - 2101, - 2102, - 2156, - 2173, - 2186, - 2252, - 2253, - 2296, - 2335, - 2360, - 2399, - 2441, - 2485, - 2502, - 2508, - 2530, - 2591, - 2609, - 2622, - 2632, - 2651, - 2676, - 2744, - 2752, - 2809, - 2871, - 2911, - 2951, - 2967, - 2968, - 2981, - 2991, - 3023, - 3031, - 3032, - 3078, - 3093, - 3115, - 3145, - 3156, - 3160, - 3183, - 3194, - 3291, - 3304, - 3351, - 3528, - 3534, - 3576, - 3596, - 3605, - 3629, - 3656, - 3665, - 3689, - 3733, - 3749, - 3789, - 3833, - 3838, - 3859, - 3863, - 3868, - 3890, - 3978, - 3984, - 3993, - 4005, - 4022, - 4112, - 4122, - 4136, - 4214, - 4257, - 4264, - 4281, - 4311, - 4318, - 4356, - 4407, - 4460, - 4467, - 4533, - 4536, - 4551, - 4586, - 4656, - 4694, - 4698, - 4725, - 4800, - 4805, - 4807, - 4848, - 4867, - 4927, - 4949, - 5103, - 5170, - 5176, - 5228, - 5325, - 5327, - 5427, - 5443, - 5462, - 5529, - 5552, - 5625, - 5660, - 5690, - 5703, - 5753, - 5757, - 5817, - 5900, - 5928, - 5948, - 5955, - 6004, - 6109, - 6113, - 6120, - 6141, - 6181, - 6206, - 6221, - 6260, - 6283, - 6404, - 6422, - 6529, - 6608, - 6631, - 6660, - 6713, - 6731, - 6736, - 6749, - 6786, - 6800, - 6913, - 6916, - 6938, - 6971, - 6993, - 7021, - 7052, - 7145, - 7180, - 7202, - 7264, - 3477, - 5197, - 6372, - 4175, - 5929, - 7029, - 1924, - 2376, - 4877, - 6463, - 765, - 4415, - 5133, - 59, - 246, - 592, - 604, - 952, - 1185, - 1362, - 2680, - 3727, - 839, - 1444, - 274, - 3265, - 3592, - 4514, - 5847, - 6005, - 6599, - 2461, - 3703, - 219, - 1731, - 1822, - 6055, - 6142, - 7289, - 5280, - 41, - 1982, - 2108, - 2247, - 2554, - 3853, - 4818, - 6768, - 6794, - 7003, - 7033, - 2733, - 4860, - 606, - 1200, - 1083, - 6039, - 651, - 797, - 1014, - 4006, - 5454, - 6826, - 6899, - 6933, - 6373, - 1415, - 1418, - 2457, - 4691, - 6342, - 621, - 602, - 946, - 5431, - 6163, - 6208, - 890, - 1668, - 2031, - 4161, - 4826, - 6183, - 1592, - 3645, - 4376, - 109, - 369, - 743, - 1432, - 2147, - 2190, - 3946, - 5720, - 6680, - 2994, - 3039, - 3781, - 4754, - 4920, - 6774, - 6942, - 2950, - 5624, - 3960, - 4890, - 4994, - 6036, - 2306, -] - -skip_list = [ -] - -fall_path_custom = { - 6558: [-1, 0, 0], - 454: [0.42, 0.9, 0], - 490: [0.97, 0.25, 0], - 910: [-0.4, 0, 0], - 1253: [-0.4, 0, 0], - 1834: [0, -0.5, 0.3], - 2004: [0.5, 0.5, 0], - 2249: [1, -1, 0], - 2382: [1, -0.5, 0], - 2468: [0.2, 0, 0], - 2498: 
[-0.2, -0.5, 0], - 2523: [1, 0, 0], - 2529: [1, 0.3, 0], - 2618: [-0.5, 0.2, 0.3], - 2688: [0, -1, 0], - 2768: [-0.86, 0.52, 0], - 3084: [0.88, -0.47, 0], - 3136: [1.0, 0, 0], - 3165: [0, 0, 0.8], - 3231: [0, -0.5, 0.3], - 3277: [0, 1, 0.3], - 3414: [0.5, 0, 0.3], - 3464: [0.7, -1, 0], - 3468: [-0.5, 0, 0], - 3686: [0.2, 0.2, 0], - 4073: [-0.24, 0.5, 0], - 4243: [0.2, 0, 0], - 4305: [0, -0.2, 0], - 4564: [-0.5, 0, 0], - 5252: [0.2, 0, 0.3], - 5328: [0, 0.5, 0], - 5401: [-1, -0.2, 0.0], - 5461: [-1.0, 0, 0.3], - 5560: [0, -0.5, 0.0], - 5609: [0.5, 0, 0.3], - 5930: [0.5, 0, 0], - 6262: [-0.5, 0, 0], - 6640: [0, -0.5, 0], - 6840: [0, -0.5, 0], - 6914: [0, -0.5, 0], - 7108: [0.5, 0, 0], - 7229: [0, -0.5, 0], - 7246: [0, 0.2, 0], - 7273: [0.5, 0, 0], - 338: [1, 1.2, 0.3], - 435: [0, 1, 0], - 2965: [0, 1, 0], - 3258: [0, 1, 0], - 1483: [0.5, 0, 0.3], - 5256: [0.8, 0, 0], - 1234: [0.2, -0.2, 0], - 1954: [0.2, -0.2, 0], - 2322: [0.2, 1, 0], - 6390: [0.2, 1, 0], - 6672: [0, 0.5, 0], - 5372: [0.5, 0, 0], - 2357: [0.3, -0.3, 0], - 95: [0.2, -0.5, 0], - 2778: [0.4, -0.5, 0], - 7281: [0.2, -0.5, 0], - 332: [-0.3, 0, 0], - 648: [-0.3, 0, 0], - 2716: [-0.2, 0, 0], - 2896: [0.2, 0.2, 0], - 3028: [0.2, 0.2, 0], - 3754: [0, 0.2, 0], - 4463: [-0.1, 0, 0], - 4615: [-0.1, 0, 0], - 5773: [-0.1, 0, 0], - 6783: [0.5, 0, 0], - 801: [0.5, 0, 0], - 5661: [0.5, 0, 0], - 675: [0, 0.5, 0], - 6526: [-0.5, 0, 0], - 7285: [-0.5, 0, 0], - 622: [0, -0.3, 0.3], - 4746: [0, -0.3, 0.3], - 1623: [0, -0.5, 0], - 5574: [0, 0.5, 0], - 1847: [0, 1.2, 0], - 2470: [0, 1.2, 0], - 2240: [-1, 0, 0], - 6694: [0, 0.2, 0], - 2180: [0.5, 0, 0], - 138: [0.5, -0.1, 0.1], - 175: [0.2, 0, 0], - 1899: [0.2, 0.2, 0], - 3858: [0, -2, 0.1], - 3952: [0.5, -0.1, 0.1], - 4156: [0.5, -0.1, 0.1], - 6077: [0.2, 0.2, 0], - 6875: [-0.2, -0.2, 0], - 7007: [-0.2, -0.2, 0], - 498: [0.5, 0, 0], - 3406: [0.5, 0, 0], - 3627: [-0.2, -0.5, 0], - 4239: [-0.3, 0, 0], - 412: [0, -0.1, 0], - 3347: [0, -0.1, 0], - 1944: [-0.2, -0.2, 0], - 2668: [-0.2, -0.2, 0], - 2749: [-0.5, 0, 0], - 1182: [0, -0.6, 0], -} - - -def revise_one_data(origin): - trajectory_id = origin['trajectory_id'] - if trajectory_id in fall_path_z_0_3: - amend_offset = [0, 0, 0.3] - elif trajectory_id in fall_path_custom: - amend_offset = fall_path_custom[trajectory_id] - else: - return origin - origin['start_position'][0] = origin['start_position'][0] + amend_offset[0] - origin['start_position'][1] = origin['start_position'][1] + amend_offset[1] - origin['start_position'][2] = origin['start_position'][2] + amend_offset[2] - origin['reference_path'][0][0] = origin['reference_path'][0][0] + amend_offset[0] - origin['reference_path'][0][1] = origin['reference_path'][0][1] + amend_offset[1] - origin['reference_path'][0][2] = origin['reference_path'][0][2] + amend_offset[2] - return origin diff --git a/internnav/utils/comm_utils/server.py b/internnav/utils/comm_utils/server.py index 2d3bf27..fbe143b 100644 --- a/internnav/utils/comm_utils/server.py +++ b/internnav/utils/comm_utils/server.py @@ -1,5 +1,6 @@ #!/usr/bin/env python import base64 +import multiprocessing import pickle from typing import Dict @@ -77,3 +78,41 @@ def run(self, reload=False): reload=reload, reload_dirs=['./internnav/agent/', './internnav/model/'], ) + + +def start_server(host='localhost', port=8087, dist=False): + """ + start a server in the backgrouond process + + Args: + host + port + + Returns: + The rank of the process group + -1, if not part of the group + + """ + ctx = multiprocessing.get_context("spawn") + p = 
ctx.Process(target=_run_server if not dist else _run_server_dist, args=(host, port)) + p.daemon = True + p.start() + print(f"Server started on {host}:{port} (pid={p.pid})") + return p + + +def _run_server_dist(host='localhost', port=8087): + import torch + + from internnav.utils.dist import get_rank + + device_idx = get_rank() + torch.cuda.set_device(device_idx) + print(f"Server using GPU {device_idx}") + server = AgentServer(host, port) + server.run() + + +def _run_server(host='localhost', port=8087): + server = AgentServer(host, port) + server.run() diff --git a/internnav/utils/common_log_util.py b/internnav/utils/common_log_util.py index b04830e..d012572 100644 --- a/internnav/utils/common_log_util.py +++ b/internnav/utils/common_log_util.py @@ -20,7 +20,7 @@ def init(task_name='default'): NAME = task_name log_dir = f'{PROJECT_ROOT_PATH}/logs/{task_name}/common' if not os.path.exists(log_dir): - os.makedirs(log_dir) + os.makedirs(log_dir, exist_ok=True) file_name = f"log_{datetime.now().strftime('%Y%m%d%H%M%S')}_{os.getpid()}.log" file_handler = logging.FileHandler(f'{log_dir}/{file_name}') file_handler.setLevel(logging.INFO) diff --git a/internnav/utils/dist.py b/internnav/utils/dist.py index 7994e25..0b7ac64 100644 --- a/internnav/utils/dist.py +++ b/internnav/utils/dist.py @@ -1,13 +1,13 @@ -import os -import time import builtins import datetime +import os import subprocess +import time +from collections import defaultdict, deque import torch import torch.distributed as dist -from collections import defaultdict, deque class SmoothedValue(object): def __init__(self, window_size=20, fmt=None): @@ -60,11 +60,8 @@ def value(self): def __str__(self): return self.fmt.format( - median=self.median, - avg=self.avg, - global_avg=self.global_avg, - max=self.max, - value=self.value) + median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value + ) class MetricLogger(object): @@ -86,15 +83,12 @@ def __getattr__(self, attr): return self.meters[attr] if attr in self.__dict__: return self.__dict__[attr] - raise AttributeError("'{}' object has no attribute '{}'".format( - type(self).__name__, attr)) + raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr)) def __str__(self): loss_str = [] for name, meter in self.meters.items(): - loss_str.append( - "{}: {}".format(name, str(meter)) - ) + loss_str.append("{}: {}".format(name, str(meter))) return self.delimiter.join(loss_str) def synchronize_between_processes(self): @@ -113,14 +107,7 @@ def log_every(self, iterable, print_freq, header=None): iter_time = SmoothedValue(fmt='{avg:.4f}') data_time = SmoothedValue(fmt='{avg:.4f}') space_fmt = ':' + str(len(str(len(iterable)))) + 'd' - log_msg = [ - header, - '[{0' + space_fmt + '}/{1}]', - 'eta: {eta}', - '{meters}', - 'time: {time}', - 'data: {data}' - ] + log_msg = [header, '[{0' + space_fmt + '}/{1}]', 'eta: {eta}', '{meters}', 'time: {time}', 'data: {data}'] if torch.cuda.is_available(): log_msg.append('max mem: {memory:.0f}') log_msg = self.delimiter.join(log_msg) @@ -133,22 +120,28 @@ def log_every(self, iterable, print_freq, header=None): eta_seconds = iter_time.global_avg * (len(iterable) - i) eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) if torch.cuda.is_available(): - print(log_msg.format( - i, len(iterable), eta=eta_string, - meters=str(self), - time=str(iter_time), data=str(data_time), - memory=torch.cuda.max_memory_allocated() / MB)) + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + 
meters=str(self), + time=str(iter_time), + data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB, + ) + ) else: - print(log_msg.format( - i, len(iterable), eta=eta_string, - meters=str(self), - time=str(iter_time), data=str(data_time))) + print( + log_msg.format( + i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time) + ) + ) i += 1 end = time.time() total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) - print('{} Total time: {} ({:.4f} s / it)'.format( - header, total_time_str, total_time / len(iterable))) + print('{} Total time: {} ({:.4f} s / it)'.format(header, total_time_str, total_time / len(iterable))) def setup_for_distributed(is_master): @@ -197,57 +190,69 @@ def save_on_master(*args, **kwargs): torch.save(*args, **kwargs) -def init_distributed_mode(args): - if 'SLURM_PROCID' in os.environ: - args.rank = int(os.environ['SLURM_PROCID']) - args.world_size = int(os.environ['SLURM_NTASKS']) - +def init_distributed_mode(dist_url="env://", port=29529, backend="nccl", timeout_hours=2): + # SLURM path: derive env then fall back to env:// + if "SLURM_PROCID" in os.environ: + rank = int(os.environ["SLURM_PROCID"]) + world_size = int(os.environ["SLURM_NTASKS"]) num_gpus = torch.cuda.device_count() - args.gpu = args.rank % num_gpus - args.local_rank = args.gpu - - node_list = os.environ['SLURM_NODELIST'] - print(f'Node list: {node_list}') - addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1') - - os.environ['MASTER_PORT'] = str(getattr(args, 'port', '29529')) - os.environ['MASTER_ADDR'] = addr - os.environ['WORLD_SIZE'] = str(args.world_size) - os.environ['LOCAL_RANK'] = str(args.gpu) - os.environ['RANK'] = str(args.rank) + local_rank = rank % max(1, num_gpus) + + # pick first node as master + nodelist = os.environ["SLURM_NODELIST"] + print(f'Node list: {nodelist}') + master_addr = subprocess.getoutput(f"scontrol show hostname {nodelist} | head -n1") + + os.environ["MASTER_ADDR"] = master_addr + os.environ["MASTER_PORT"] = str(port) + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["LOCAL_RANK"] = str(local_rank) + + # Fast-path: torchrun provides these elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: - args.rank = int(os.environ["RANK"]) - args.world_size = int(os.environ['WORLD_SIZE']) - args.gpu = int(os.environ['LOCAL_RANK']) - args.local_rank = args.gpu + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + if "LOCAL_RANK" in os.environ: + local_rank = int(os.environ["LOCAL_RANK"]) + elif "RANK" in os.environ: + # fallback: assume per-node GPU count n + num_gpus = torch.cuda.device_count() + local_rank = rank % max(1, num_gpus) + else: + local_rank = 0 + else: - print('Not using distributed mode') - setup_for_distributed(is_master=True) # hack - args.distributed = False - return - - args.distributed = True - - torch.cuda.set_device(args.gpu) - args.dist_backend = 'nccl' - print('| distributed init (rank {}): {}, gpu {}'.format(args.rank, args.dist_url, args.gpu), flush=True) - dist.init_process_group(backend=args.dist_backend, - init_method=args.dist_url, - world_size=args.world_size, - rank=args.rank, - timeout=datetime.timedelta(0, 7200)) + print("Not using distributed mode") + setup_for_distributed(is_master=True) + return 0 + + import socket + + print(f"Rank {os.getenv('RANK')} / {os.getenv('WORLD_SIZE')} on {socket.gethostname()}:{os.getenv('MASTER_PORT')}") + print('| distributed 
init (rank {}): {}, gpu {}'.format(rank, dist_url, local_rank), flush=True) + + # Device selection must happen before NCCL init + torch.cuda.set_device(local_rank) + + dist.init_process_group( + backend=backend, init_method=dist_url, world_size=world_size, rank=rank, timeout=datetime.timedelta(0, 7200) + ) dist.barrier() - setup_for_distributed(args.rank == 0) + setup_for_distributed(dist.get_rank() == 0) + return local_rank + def save_model(args, epoch, model_without_ddp, optimizer, checkpoint_path): to_save = { - 'model': model_without_ddp.state_dict(), - 'optimizer': optimizer.state_dict(), - 'epoch': epoch, - 'args': args, - } + 'model': model_without_ddp.state_dict(), + 'optimizer': optimizer.state_dict(), + 'epoch': epoch, + 'args': args, + } save_on_master(to_save, checkpoint_path) + def all_reduce_mean(x): world_size = get_world_size() if world_size > 1: @@ -257,11 +262,16 @@ def all_reduce_mean(x): return x_reduce.item() else: return x - + + def fsdp_auto_wrap_policy(model, transformer_layer_names): import functools - from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy + from torch.distributed.fsdp.wrap import ( + _or_policy, + lambda_auto_wrap_policy, + transformer_auto_wrap_policy, + ) def lambda_policy_fn(module): if ( @@ -274,9 +284,8 @@ def lambda_policy_fn(module): lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn) transformer_wrap_policy = functools.partial( - transformer_auto_wrap_policy, - transformer_layer_cls=set(transformer_layer_names) + transformer_auto_wrap_policy, transformer_layer_cls=set(transformer_layer_names) ) auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy]) - return auto_wrap_policy \ No newline at end of file + return auto_wrap_policy diff --git a/internnav/utils/progress_log_multi_util.py b/internnav/utils/progress_log_multi_util.py index c7c60b1..d2f8d82 100644 --- a/internnav/utils/progress_log_multi_util.py +++ b/internnav/utils/progress_log_multi_util.py @@ -61,13 +61,13 @@ def init(dataset_name, path_count): global INITED PROGRESS = ProgressInfo(dataset_name, path_count) log_dir = f'{PROJECT_ROOT_PATH}/logs/{get_task_name()}/progress/' - if not os.path.exists(log_dir): - os.makedirs(log_dir) + os.makedirs(log_dir, exist_ok=True) file_handler = logging.FileHandler(f'{log_dir}/{dataset_name}.log') file_handler.setLevel(logging.INFO) formatter = logging.Formatter('[%(asctime)s][%(levelname)s] %(message)s') file_handler.setFormatter(formatter) progress_logger_multi.addHandler(file_handler) + progress_logger_multi.disabled = False INITED = True diff --git a/scripts/eval/bash/eval_dual_system.sh b/scripts/eval/bash/eval_dual_system.sh index ef4be1e..cb52d14 100755 --- a/scripts/eval/bash/eval_dual_system.sh +++ b/scripts/eval/bash/eval_dual_system.sh @@ -1,10 +1,7 @@ -export MAGNUM_LOG=quiet HABITAT_SIM_LOG=quiet -export NCCL_SOCKET_IFNAME=bond0 -export NCCL_IB_HCA=mlx5_2,mlx5_3,mlx5_4,mlx5_5 - MID_RUN_NAME="InternVLA-N1" +CONFIG="scripts/eval/configs/habitat_dual_system_cfg.py" -srun -p efm_t \ +srun -p \ --gres=gpu:8 \ --ntasks=8 \ -x HOST-10-140-66-68,HOST-10-140-66-182,HOST-10-140-66-181 \ @@ -12,8 +9,6 @@ srun -p efm_t \ --ntasks-per-node=8 \ --cpus-per-task=16 \ --kill-on-bad-exit=1 \ - python scripts/eval/eval_habitat.py \ - --model_path checkpoints/${MID_RUN_NAME} \ - --predict_step_nums 32 \ - --continuous_traj \ - --output_path results/$MID_RUN_NAME/val_unseen_32traj_8steps \ + python scripts/eval/eval.py \ + 
--config $CONFIG \ + > logs/${MID_RUN_NAME}_log.txt 2>&1 diff --git a/scripts/eval/bash/eval_system2.sh b/scripts/eval/bash/eval_system2.sh index bedfbb3..409b099 100755 --- a/scripts/eval/bash/eval_system2.sh +++ b/scripts/eval/bash/eval_system2.sh @@ -1,11 +1,7 @@ -export MAGNUM_LOG=quiet HABITAT_SIM_LOG=quiet -export NCCL_SOCKET_IFNAME=bond0 -export NCCL_IB_HCA=mlx5_2,mlx5_3,mlx5_4,mlx5_5 +MID_RUN_NAME="InternVLA-N1" +CONFIG="scripts/eval/configs/habitat_s2_cfg.py" - -MID_RUN_NAME="vln_one_stage_with_qa_bs256_backup-checkpoint-21000" - -srun -p efm_t \ +srun -p \ --gres=gpu:8 \ --ntasks=8 \ -x HOST-10-140-66-68,HOST-10-140-66-182,HOST-10-140-66-181 \ @@ -13,7 +9,6 @@ srun -p efm_t \ --ntasks-per-node=8 \ --cpus-per-task=16 \ --kill-on-bad-exit=1 \ - python scripts/eval/eval_habitat.py \ - --model_path /path/to/${MID_RUN_NAME} \ - --mode system2 \ - --output_path results/$MID_RUN_NAME/val_unseen \ + python scripts/eval/eval.py \ + --config $CONFIG \ + > logs/${MID_RUN_NAME}_log.txt 2>&1 diff --git a/scripts/eval/bash/eval_vln_distributed.sh b/scripts/eval/bash/eval_vln_distributed.sh new file mode 100644 index 0000000..ae2060e --- /dev/null +++ b/scripts/eval/bash/eval_vln_distributed.sh @@ -0,0 +1,106 @@ +#!/bin/bash +# ---------------------------------------------------------------------------- +# USAGE INSTRUCTIONS: +# +# This script runs a distributed evaluation for Habitat or Internutopia. +# You can choose from the following modes: +# +# 1. **habitat**: Runs evaluation using the Habitat environment. +# 2. **internutopia**: Runs evaluation using the Internutopia environment. +# 3. **internutopia_vec_env**: Runs distributed evaluation with Ray in vectorized environments. +# +# The script automatically activates the appropriate Conda environment based on the mode. +# You can specify a custom configuration file using the `--config` argument. +# +# Example usage: +# ./scripts/eval/bash/eval_vln_distributed.sh habitat --config +# ./ internutopia --config +# ./ internutopia_vec_env --config +# +# Make sure required Conda environments (habitat, internutopia) and Ray (for internutopia_vec_env) are set up. +# ---------------------------------------------------------------------------- + +# Activate conda (update to your local path) +source /root/miniconda3/etc/profile.d/conda.sh + +mode="$1" +shift # remove first argument so only extra args left (--config ...) + +CONFIG=scripts/eval/configs/h1_internvla_n1_async_cfg.py +while [[ $# -gt 0 ]]; do + case $1 in + --config) + CONFIG="$2" + shift 2 + ;; + *) + echo "Unknown argument: $1" + exit 1 + ;; + esac +done + +case "$mode" in + # start to evaluate habitat in dlc + habitat) + echo "[run.sh] Starting HABITAT evaluation..." + + conda activate habitat + + python scripts/eval/eval.py \ + --config $CONFIG + + ;; + internutopia) + echo "[run.sh] Starting INTERNUTOPIA evaluation..." + + conda activate internutopia + + python scripts/eval/eval.py \ + --config $CONFIG + + ;; + internutopia_vec_env) + echo "[run.sh] Starting INTERNUTOPIA evaluation..." + + conda activate internutopia + + # -------- parse remaining arguments (e.g., --config xxx) -------- + while [[ $# -gt 0 ]]; do + case $1 in + --config) + CONFIG="$2" + shift 2 + ;; + *) + echo "Unknown parameter: $1" + exit 1 + ;; + esac + done + # ---------------------------------------------------------------- + + if [ "$RANK" -eq 0 ]; then + echo "[run.sh] Starting Ray head..." 
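A note on the internutopia_vec_env branch here: rank 0 brings up the Ray head just below, the other ranks join it via MASTER_ADDR, and the actual evaluation is then launched by start_eval.sh. Purely as an illustration (not part of this diff), a Python process on that cluster can attach to it and sanity-check the resources before any work is dispatched:

import ray

# Illustrative check only: assumes `ray start --head` (rank 0) and
# `ray start --address=$MASTER_ADDR:6379` (workers) have already run on the nodes.
ray.init(address="auto")              # attach to the existing cluster, do not start a new one
print(ray.cluster_resources())        # aggregate CPUs/GPUs visible across all joined nodes
print(len(ray.nodes()), "nodes joined")
ray.shutdown()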
+ RAY_max_direct_call_object_size=104857600 \ + ray start --head --port=6379 + + sleep 20s + + echo "[run.sh] Exec start_eval.sh..." + bash scripts/eval/bash/start_eval.sh + + sleep inf + else + echo "[run.sh] Starting Ray worker..." + RAY_max_direct_call_object_size=104857600 \ + ray start --address=${MASTER_ADDR}:6379 + + sleep inf + fi + ;; + *) + echo "Usage: $0 {habitat|internutopia|internutopia_vec_env} [--config xxx]" + exit 1 + ;; +esac diff --git a/scripts/eval/bash/start_aliyun_dlc.sh b/scripts/eval/bash/start_aliyun_dlc.sh deleted file mode 100755 index 14d5e90..0000000 --- a/scripts/eval/bash/start_aliyun_dlc.sh +++ /dev/null @@ -1,26 +0,0 @@ -#!/bin/bash -source /root/miniconda3/etc/profile.d/conda.sh -conda activate internutopia - -while [[ $# -gt 0 ]]; do - case $1 in - --config) - CONFIG="$2" - shift 2 - ;; - *) - echo "Unknown parameter: $1" - exit 1 - ;; - esac -done - -if [ "$RANK" -eq 0 ]; then - RAY_max_direct_call_object_size=104857600 ray start --head --port=6379 - sleep 20s - bash scripts/eval/start_eval_iros.sh - sleep inf -else - RAY_max_direct_call_object_size=104857600 ray start --address=${MASTER_ADDR}:6379 - sleep inf -fi diff --git a/scripts/eval/bash/start_eval.sh b/scripts/eval/bash/start_eval.sh index ce41df1..a7b3639 100755 --- a/scripts/eval/bash/start_eval.sh +++ b/scripts/eval/bash/start_eval.sh @@ -3,7 +3,7 @@ source /root/miniconda3/etc/profile.d/conda.sh conda activate internutopia -CONFIG=scripts/eval/configs/h1_cma_cfg.py +CONFIG=scripts/eval/configs/h1_rdp_cfg.py while [[ $# -gt 0 ]]; do case $1 in diff --git a/scripts/eval/bash/torchrun_eval.sh b/scripts/eval/bash/torchrun_eval.sh new file mode 100644 index 0000000..5747eea --- /dev/null +++ b/scripts/eval/bash/torchrun_eval.sh @@ -0,0 +1,28 @@ +# use to run distributed eval with multi gpus + +CONFIG=scripts/eval/configs/h1_internvla_n1_async_cfg.py + +while [[ $# -gt 0 ]]; do + case $1 in + --config) + CONFIG="$2" + shift 2 + ;; + *) + echo "Unknown argument: $1" + exit 1 + ;; + esac +done + +# Extract the prefix from the config filename +CONFIG_BASENAME=$(basename "$CONFIG" .py) +CONFIG_PREFIX=$(echo "$CONFIG_BASENAME" | sed 's/_cfg$//') +EVAL_LOG="logs/${CONFIG_PREFIX}_eval.log" + +torchrun \ + --nproc_per_node=1 \ + --master_port=2333 \ + scripts/eval/eval.py \ + --config $CONFIG \ + > $EVAL_LOG 2>&1 diff --git a/scripts/eval/configs/h1_cma_cfg.py b/scripts/eval/configs/h1_cma_cfg.py index 6b27ee9..6c509d4 100644 --- a/scripts/eval/configs/h1_cma_cfg.py +++ b/scripts/eval/configs/h1_cma_cfg.py @@ -24,9 +24,9 @@ task=TaskCfg( task_name='cma_plus_eval', task_settings={ - 'env_num': 2, + 'env_num': 1, 'use_distributed': False, - 'proc_num': 2, + 'proc_num': 8, }, scene=SceneCfg( scene_type='mp3d', @@ -44,8 +44,13 @@ dataset_settings={ 'base_data_dir': 'data/vln_pe/raw_data/r2r', 'split_data_types': ['val_unseen', 'val_seen'], - 'filter_stairs': False, + 'filter_stairs': True, }, ), - eval_settings={'save_to_json': False, 'vis_output': True}, + eval_type='vln_distributed', + eval_settings={ + 'save_to_json': True, + 'vis_output': True, + 'use_agent_server': True, + }, ) diff --git a/scripts/eval/configs/h1_internvla_n1_async_cfg.py b/scripts/eval/configs/h1_internvla_n1_async_cfg.py index 9ee5b83..36f1012 100644 --- a/scripts/eval/configs/h1_internvla_n1_async_cfg.py +++ b/scripts/eval/configs/h1_internvla_n1_async_cfg.py @@ -27,7 +27,7 @@ 'num_frames': 32, 'num_history': 8, 'num_future_steps': 4, - 'device': 'cuda:1', + 'device': 'cuda:0', 'predict_step_nums': 32, 'continuous_traj': True, 
'infer_mode': 'partial_async', # You can choose "sync" or "partial_async", but for this model, "partial_async" is better. @@ -49,7 +49,7 @@ 'env_num': 1, 'use_distributed': False, # If the others setting in task_settings, please set use_distributed = False. 'proc_num': 1, - # 'max_step': 1000, #If use flash mode,default 1000; descrete mode, set 50000 + 'max_step': 1000, # If use flash mode,default 1000; descrete mode, set 50000 }, scene=SceneCfg( scene_type='mp3d', @@ -66,10 +66,16 @@ dataset_type="mp3d", dataset_settings={ 'base_data_dir': 'data/vln_pe/raw_data/r2r', - 'split_data_types': ['val_seen', 'val_unseen'], # 'val_seen' - 'filter_stairs': False, # For iros challenge, this is False; For results in the paper, this is True. + 'split_data_types': ['val_unseen'], # 'val_seen' + 'filter_stairs': True, # For iros challenge, this is False; For results in the paper, this is True. # 'selected_scans': ['zsNo4HB9uLZ'], # 'selected_scans': ['8194nk5LbLH', 'pLe4wQe7qrG'], }, ), + eval_type='vln_distributed', + eval_settings={ + 'save_to_json': True, + 'vis_output': False, + 'use_agent_server': False, # If use_agent_server=True, please start the agent server first. + }, ) diff --git a/scripts/eval/configs/h1_internvla_n1_cfg.py b/scripts/eval/configs/h1_internvla_n1_cfg.py deleted file mode 100644 index 90a801c..0000000 --- a/scripts/eval/configs/h1_internvla_n1_cfg.py +++ /dev/null @@ -1,72 +0,0 @@ -from internnav.configs.agent import AgentCfg -from internnav.configs.evaluator import ( - EnvCfg, - EvalCfg, - EvalDatasetCfg, - SceneCfg, - TaskCfg, -) - -eval_cfg = EvalCfg( - agent=AgentCfg( - server_port=8087, - model_name='internvla_n1', - ckpt_path='', - model_settings={ - 'env_num': 1, - 'sim_num': 1, - 'model_path': "checkpoints/InternVLA-N1", - 'camera_intrinsic': [[585.0, 0.0, 320.0], [0.0, 585.0, 240.0], [0.0, 0.0, 1.0]], - 'width': 640, - 'height': 480, - 'hfov': 79, - 'resize_w': 384, - 'resize_h': 384, - 'max_new_tokens': 1024, - 'num_frames': 32, - 'num_history': 8, - 'num_future_steps': 4, - 'device': 'cuda:0', - 'predict_step_nums': 32, - 'continuous_traj': True, - # debug - 'vis_debug': True, # If vis_debug=True, you can get visualization results - 'vis_debug_path': './logs/test/vis_debug', - }, - ), - env=EnvCfg( - env_type='internutopia', - env_settings={ - 'use_fabric': False, # Please set use_fabric=False due to the render delay; - 'headless': True, - }, - ), - task=TaskCfg( - task_name='test', - task_settings={ - 'env_num': 1, - 'use_distributed': False, # If the others setting in task_settings, please set use_distributed = False. - 'proc_num': 1, - }, - scene=SceneCfg( - scene_type='mp3d', - scene_data_dir='data/scene_data/mp3d_pe', - ), - robot_name='h1', - robot_flash=True, # If robot_flash is True, the mode is flash (set world_pose directly); else you choose physical mode. - robot_usd_path='data/Embodiments/vln-pe/h1/h1_internvla.usd', - camera_resolution=[640, 480], # (W,H) - camera_prim_path='torso_link/h1_1_25_down_30', - one_step_stand_still=True, # For dual-system, please keep this param True. 
- ), - dataset=EvalDatasetCfg( - dataset_type="mp3d", - dataset_settings={ - 'base_data_dir': 'data/vln_pe/raw_data/r2r', - 'split_data_types': ['val_unseen'], # 'val_seen' - 'filter_stairs': False, - # 'selected_scans': ['zsNo4HB9uLZ'], - # 'selected_scans': ['8194nk5LbLH', 'pLe4wQe7qrG'], - }, - ), -) diff --git a/scripts/eval/configs/h1_rdp_cfg.py b/scripts/eval/configs/h1_rdp_cfg.py index a6380c8..ef5ff19 100644 --- a/scripts/eval/configs/h1_rdp_cfg.py +++ b/scripts/eval/configs/h1_rdp_cfg.py @@ -26,7 +26,7 @@ task_settings={ 'env_num': 2, 'use_distributed': True, - 'proc_num': 1, + 'proc_num': 4, }, scene=SceneCfg( scene_type='mp3d', @@ -41,8 +41,14 @@ dataset_type="mp3d", dataset_settings={ 'base_data_dir': 'data/vln_pe/raw_data/r2r', - 'split_data_types': ['val_unseen', 'val_seen'], - 'filter_stairs': False, + 'split_data_types': ['val_unseen'], + 'filter_stairs': True, }, ), + eval_type='vln_distributed', + eval_settings={ + 'save_to_json': True, + 'vis_output': False, + 'use_agent_server': True, + }, ) diff --git a/scripts/eval/configs/h1_seq2seq_cfg.py b/scripts/eval/configs/h1_seq2seq_cfg.py index 2934e8e..b8696be 100644 --- a/scripts/eval/configs/h1_seq2seq_cfg.py +++ b/scripts/eval/configs/h1_seq2seq_cfg.py @@ -22,11 +22,11 @@ }, ), task=TaskCfg( - task_name='seq2seq_eval', + task_name='seq_eval', task_settings={ - 'env_num': 2, - 'use_distributed': True, - 'proc_num': 1, + 'env_num': 1, + 'use_distributed': False, + 'proc_num': 8, }, scene=SceneCfg( scene_type='mp3d', @@ -36,13 +36,21 @@ robot_usd_path='data/Embodiments/vln-pe/h1/h1_vln_pointcloud.usd', camera_resolution=[256, 256], # (W,H) camera_prim_path='torso_link/h1_pano_camera_0', + vlnce=False, # vlnpe by default + obstacle_detection=False, # whether allow flash across obstacle ), dataset=EvalDatasetCfg( dataset_type="mp3d", dataset_settings={ 'base_data_dir': 'data/vln_pe/raw_data/r2r', 'split_data_types': ['val_unseen', 'val_seen'], - 'filter_stairs': False, + 'filter_stairs': True, }, ), + eval_type='vln_distributed', + eval_settings={ + 'save_to_json': True, + 'vis_output': True, + 'use_agent_server': True, + }, ) diff --git a/scripts/eval/configs/habitat_dual_system_cfg.py b/scripts/eval/configs/habitat_dual_system_cfg.py new file mode 100644 index 0000000..7604804 --- /dev/null +++ b/scripts/eval/configs/habitat_dual_system_cfg.py @@ -0,0 +1,38 @@ +from internnav.configs.agent import AgentCfg +from internnav.configs.evaluator import EnvCfg, EvalCfg + +eval_cfg = EvalCfg( + agent=AgentCfg( + model_name='internvla_n1', + model_settings={ + "mode": "dual_system", # inference mode: dual_system or system2 + "model_path": "checkpoints/InternVLA-N1", # path to model checkpoint + "num_future_steps": 4, # number of future steps for prediction + "num_frames": 32, # number of frames used in evaluation + "num_history": 8, + "resize_w": 384, # image resize width + "resize_h": 384, # image resize height + "predict_step_nums": 32, # number of steps to predict + "continuous_traj": True, # whether to use continuous trajectory + "max_new_tokens": 1024, # maximum number of tokens for generation + }, + ), + env=EnvCfg( + env_type='habitat', + env_settings={ + # habitat sim specifications - agent, sensors, tasks, measures etc. 
are defined in the habitat config file + 'config_path': 'scripts/eval/configs/vln_r2r.yaml', + }, + ), + eval_type='habitat_vln', + eval_settings={ + # all current parse args + "output_path": "./logs/habitat/test_dual_system", # output directory for logs/results + "save_video": False, # whether to save videos + "epoch": 0, # epoch number for logging + "max_steps_per_episode": 500, # maximum steps per episode + # distributed settings + "port": "2333", # communication port + "dist_url": "env://", # url for distributed setup + }, +) diff --git a/scripts/eval/configs/habitat_r2r_pix.yaml b/scripts/eval/configs/habitat_r2r_pix.yaml deleted file mode 100644 index 454cdef..0000000 --- a/scripts/eval/configs/habitat_r2r_pix.yaml +++ /dev/null @@ -1,83 +0,0 @@ -# @package _global_ - -defaults: - - /habitat: habitat_config_base - - /habitat/task: vln_r2r - - /habitat/simulator/agents@habitat.simulator.agents.main_agent: rgbd_agent - # - /habitat/simulator/sensor_setups@habitat.simulator.agents.main_agent: rgbd_agent - - /habitat/dataset/vln: mp3d_r2r - - /habitat/task/lab_sensors: - - gps_sensor - - compass_sensor - - _self_ - -habitat: - environment: - max_episode_steps: 10000 - iterator_options: - max_scene_repeat_steps: 50000 - shuffle: False - simulator: - agents: - main_agent: - sim_sensors: - rgb_sensor: - width: 640 - height: 480 - hfov: 79 - # hfov: 69 - # position: [0, 0.6, 0] - depth_sensor: - width: 640 - height: 480 - hfov: 79 - # hfov: 69 - min_depth: 0.0 - max_depth: 10.0 - # position: [0, 0.6, 0] - forward_step_size: 0.25 - turn_angle: 15 - tilt_angle: 15 - action_space_config: "v1" - habitat_sim_v0: - gpu_device_id: 0 - task: - measurements: - distance_to_goal: - type: DistanceToGoal - distance_to: POINT - success: - type: Success - success_distance: 3.0 - spl: - type: SPL - oracle_success: - type: OracleSuccess - # success_distance: 3.0 - oracle_navigation_error: - type: OracleNavigationError - actions: - stop: - type: StopAction - agent_index: 0 - move_forward: - type: MoveForwardAction - agent_index: 0 - turn_left: - type: TurnLeftAction - agent_index: 0 - turn_right: - type: TurnRightAction - agent_index: 0 - look_up: - type: LookUpAction - agent_index: 0 - look_down: - type: LookDownAction - agent_index: 0 - - dataset: - type: R2RVLN-v1 - split: val_seen - scenes_dir: data/scene_data/mp3d_ce - data_path: data/datasets/vln/mp3d/r2r/v1/{split}/{split}.json.gz diff --git a/scripts/eval/configs/habitat_s2_cfg.py b/scripts/eval/configs/habitat_s2_cfg.py new file mode 100644 index 0000000..3debe85 --- /dev/null +++ b/scripts/eval/configs/habitat_s2_cfg.py @@ -0,0 +1,38 @@ +from internnav.configs.agent import AgentCfg +from internnav.configs.evaluator import EnvCfg, EvalCfg + +eval_cfg = EvalCfg( + agent=AgentCfg( + model_name='internvla_n1', + model_settings={ + "mode": "system2", # inference mode: dual_system or system2 + "model_path": "checkpoints/", # path to model checkpoint + "num_future_steps": 4, # number of future steps for prediction + "num_frames": 32, # number of frames used in evaluation + "num_history": 8, + "resize_w": 384, # image resize width + "resize_h": 384, # image resize height + "predict_step_nums": 32, # number of steps to predict + "continuous_traj": True, # whether to use continuous trajectory + "max_new_tokens": 1024, # maximum number of tokens for generation + }, + ), + env=EnvCfg( + env_type='habitat', + env_settings={ + # habitat sim specifications - agent, sensors, tasks, measures etc. 
are defined in the habitat config file + 'config_path': 'scripts/eval/configs/vln_r2r.yaml', + }, + ), + eval_type='habitat_vln', + eval_settings={ + # all current parse args + "output_path": "./logs/habitat/test_s2", # output directory for logs/results + "save_video": False, # whether to save videos + "epoch": 0, # epoch number for logging + "max_steps_per_episode": 500, # maximum steps per episode + # distributed settings + "port": "2333", # communication port + "dist_url": "env://", # url for distributed setup + }, +) diff --git a/scripts/eval/configs/vln_r2r.yaml b/scripts/eval/configs/vln_r2r.yaml index ed8361c..e6b1a89 100644 --- a/scripts/eval/configs/vln_r2r.yaml +++ b/scripts/eval/configs/vln_r2r.yaml @@ -72,6 +72,6 @@ habitat: dataset: type: R2RVLN-v1 - split: val_seen + split: val_unseen scenes_dir: data/scene_data/mp3d_ce data_path: data/vln_ce/raw_data/r2r/{split}/{split}.json.gz diff --git a/scripts/eval/eval.py b/scripts/eval/eval.py index 68c507a..4c17026 100644 --- a/scripts/eval/eval.py +++ b/scripts/eval/eval.py @@ -5,7 +5,6 @@ import argparse import importlib.util -from internnav.configs.evaluator.vln_default_config import get_config from internnav.evaluator import Evaluator # This file is the main file @@ -33,10 +32,15 @@ def load_eval_cfg(config_path, attr_name='eval_cfg'): def main(): args = parse_args() evaluator_cfg = load_eval_cfg(args.config, attr_name='eval_cfg') - cfg = get_config(evaluator_cfg) - print(cfg) - evaluator = Evaluator.init(cfg) - print(type(evaluator)) + + # fill in evaluator default config + if evaluator_cfg.eval_type == 'vln_distributed': + from internnav.configs.evaluator.vln_default_config import get_config + + evaluator_cfg = get_config(evaluator_cfg) + + # create evaluator based on sim backend and run eval + evaluator = Evaluator.init(evaluator_cfg) evaluator.eval() diff --git a/scripts/eval/eval_habitat.py b/scripts/eval/eval_habitat.py deleted file mode 100644 index e78a8d6..0000000 --- a/scripts/eval/eval_habitat.py +++ /dev/null @@ -1,127 +0,0 @@ -import argparse -import json -import os -import sys - -sys.path.append('./src/diffusion-policy') - -import numpy as np -import torch -from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration - -# Import for Habitat registry side effects — do not remove -import internnav.env.utils.habitat_extensions.measures # noqa: F401 -from internnav.evaluator.habitat_vln_evaluator import VLNEvaluator -from internnav.model.basemodel.internvla_n1.internvla_n1 import InternVLAN1ForCausalLM -from internnav.utils.dist import * - - -def parse_args(): - - parser = argparse.ArgumentParser(description='Evaluate InternVLA-N1 on Habitat') - parser.add_argument("--mode", default='dual_system', type=str, help="inference mode: dual_system or system2") - parser.add_argument("--local_rank", default=0, type=int, help="node rank") - parser.add_argument("--model_path", type=str, default="") - parser.add_argument("--habitat_config_path", type=str, default='scripts/eval/configs/vln_r2r.yaml') - parser.add_argument("--eval_split", type=str, default='val_unseen') - parser.add_argument("--output_path", type=str, default='./logs/habitat/test') #! 
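The eval.py change above relies on load_eval_cfg() to pull an `eval_cfg` object out of whatever Python file `--config` points at. As a self-contained sketch of that pattern (mirroring the helper used by eval.py and start_server.py in this diff, not adding new behavior):

import importlib.util

def load_eval_cfg(config_path, attr_name="eval_cfg"):
    # Execute the config file as an anonymous module and return the attribute it defines.
    spec = importlib.util.spec_from_file_location("eval_config_module", config_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return getattr(module, attr_name)

# e.g. cfg = load_eval_cfg("scripts/eval/configs/habitat_dual_system_cfg.py")
#      cfg.eval_type  ->  "habitat_vln"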
- parser.add_argument("--num_future_steps", type=int, default=4) - parser.add_argument("--num_frames", type=int, default=32) - parser.add_argument("--save_video", action="store_true", default=False) - parser.add_argument("--num_history", type=int, default=8) - parser.add_argument("--resize_w", type=int, default=384) - parser.add_argument("--resize_h", type=int, default=384) - parser.add_argument("--predict_step_nums", type=int, default=16) - parser.add_argument("--continuous_traj", action="store_true", default=False) - parser.add_argument("--max_new_tokens", type=int, default=1024) - - parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') - parser.add_argument('--rank', default=0, type=int, help='rank') - parser.add_argument('--gpu', default=0, type=int, help='gpu') - parser.add_argument('--port', default='2333') - parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') - parser.add_argument('--device', default='cuda', help='device to use for training / testing') - - return parser.parse_args() - - -def main(): - args = parse_args() - - init_distributed_mode(args) - local_rank = args.local_rank - np.random.seed(local_rank) - - # * 1. Load model and tokenizer. Currently, we support two modes: dual_system and system2 in Habitat. - processor = AutoProcessor.from_pretrained(args.model_path) - processor.tokenizer.padding_side = 'left' - - device = torch.device(f"cuda:{local_rank}") - if args.mode == 'dual_system': - model = InternVLAN1ForCausalLM.from_pretrained( - args.model_path, - torch_dtype=torch.bfloat16, - attn_implementation="flash_attention_2", - device_map={"": device}, - ) - elif args.mode == 'system2': - model = Qwen2_5_VLForConditionalGeneration.from_pretrained( - args.model_path, - torch_dtype=torch.bfloat16, - attn_implementation="flash_attention_2", - device_map={"": device}, - ) - else: - raise ValueError(f"Invalid mode: {args.mode}") - - model.eval() - world_size = get_world_size() - - # * 2. initialize evaluator - evaluator = VLNEvaluator( - config_path=args.habitat_config_path, - split=args.eval_split, - env_num=world_size, - output_path=args.output_path, - model=model, - processor=processor, - epoch=0, - args=args, - ) - - # * 3. 
do eval - sucs, spls, oss, nes, ep_num = evaluator.eval_action(idx=get_rank()) - ep_num_all = [torch.zeros_like(ep_num) for _ in range(world_size)] - - # import ipdb; ipdb.set_trace() - dist.all_gather(ep_num_all, ep_num) - sucs_all = [torch.zeros(ep_num_all[i], dtype=sucs.dtype).to(sucs.device) for i in range(world_size)] - spls_all = [torch.zeros(ep_num_all[i], dtype=spls.dtype).to(spls.device) for i in range(world_size)] - oss_all = [torch.zeros(ep_num_all[i], dtype=oss.dtype).to(oss.device) for i in range(world_size)] - nes_all = [torch.zeros(ep_num_all[i], dtype=nes.dtype).to(nes.device) for i in range(world_size)] - dist.barrier() - dist.all_gather(sucs_all, sucs) - dist.all_gather(spls_all, spls) - dist.all_gather(oss_all, oss) - dist.all_gather(nes_all, nes) - - sucs_all = torch.cat(sucs_all, dim=0) - spls_all = torch.cat(spls_all, dim=0) - oss_all = torch.cat(oss_all, dim=0) - nes_all = torch.cat(nes_all, dim=0) - result_all = { - "sucs_all": (sum(sucs_all) / len(sucs_all)).item(), - "spls_all": (sum(spls_all) / len(spls_all)).item(), - "oss_all": (sum(oss_all) / len(oss_all)).item(), - "nes_all": (sum(nes_all) / len(nes_all)).item(), - 'length': len(sucs_all), - } - - print(result_all) - if get_rank() == 0: - with open(os.path.join(args.output_path, f'result.json'), 'a') as f: - f.write(json.dumps(result_all)) - - -if __name__ == '__main__': - main() diff --git a/scripts/eval/start_server.py b/scripts/eval/start_server.py index 5a03cd7..60e6007 100644 --- a/scripts/eval/start_server.py +++ b/scripts/eval/start_server.py @@ -5,27 +5,15 @@ sys.path.append('./src/diffusion-policy') import argparse -import glob import importlib import importlib.util -import os import sys +# Import for agent registry side effects — do not remove +from internnav.agent import Agent # noqa: F401 from internnav.utils import AgentServer -# import all agents to register them -def auto_register_agents(agent_dir: str): - # Get all Python files in the agents directory - agent_modules = glob.glob(os.path.join(agent_dir, '*.py')) - - # Import each module to trigger the registration - for module in agent_modules: - if not module.endswith('__init__.py'): # Avoid importing __init__.py itself - module_name = os.path.basename(module)[:-3] # Remove the .py extension - importlib.import_module(f'internnav.agent.{module_name}') # Replace 'agents' with your module's package - - def load_eval_cfg(config_path, attr_name='eval_cfg'): spec = importlib.util.spec_from_file_location("eval_config_module", config_path) config_module = importlib.util.module_from_spec(spec) @@ -37,9 +25,6 @@ def load_eval_cfg(config_path, attr_name='eval_cfg'): if __name__ == '__main__': print("Starting Agent Server...") - print("Registering agents...") - auto_register_agents('internnav/agent') - parser = argparse.ArgumentParser() parser.add_argument('--host', type=str, default='localhost') parser.add_argument( diff --git a/scripts/iros_challenge/start_eval_iros.sh b/scripts/iros_challenge/start_eval_iros.sh index d5f43ac..5701241 100755 --- a/scripts/iros_challenge/start_eval_iros.sh +++ b/scripts/iros_challenge/start_eval_iros.sh @@ -40,14 +40,14 @@ mkdir -p logs SERVER_LOG="logs/${CONFIG_PREFIX}_server.log" EVAL_LOG="logs/${CONFIG_PREFIX}_eval.log" -processes=$(ps -ef | grep 'internnav/agent/utils/server.py' | grep -v grep | awk '{print $2}') +processes=$(ps -ef | grep 'scripts/eval/start_server.py' | grep -v grep | awk '{print $2}') if [ -n "$processes" ]; then for pid in $processes; do kill -9 $pid echo "kill: $pid" done fi -python 
internnav/agent/utils/server.py --config scripts/eval/configs/challenge_cfg.py > "$SERVER_LOG" 2>&1 & +python scripts/eval/start_server.py --config scripts/eval/configs/challenge_cfg.py > "$SERVER_LOG" 2>&1 & START_COMMAND_KUJIALE="python -u scripts/eval/eval_iros.py --config $CONFIG --default_config scripts/eval/configs/challenge_kujiale_cfg.py --split $SPLIT" diff --git a/setup.cfg b/setup.cfg index 3aeaebe..dd5d8f9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -10,7 +10,7 @@ extra_standard_library = pkg_resources,setuptools known_first_party = internutopia, internutopia_extension, grevaluator, grbench, grmodel no_lines_before = STDLIB,LOCALFOLDER default_section = THIRDPARTY -skip_glob = internutopia/*, internutopia_extension/* +skip_glob = internutopia/*, internutopia_extension/*, internnav/scripts/eval/configs/* # ignore-words-list needs to be lowercase format. For example, if we want to @@ -45,4 +45,4 @@ per-file-ignores=*/__init__.py:F401 ignore=E402,E501,W503,E203,D401,R504,R505,SIM102,SIM117,E711,E226 max-line-length = 120 max-complexity = 30 -exclude=_*,.vscode,.git,docs/**,**/test/**,**/lcmtypes/**,*.ipynb,scripts/**,internnav/projects/** +exclude=_*,.vscode,.git,docs/**,**/test/**,**/lcmtypes/**,*.ipynb diff --git a/tests/function_test/e2e_test.py b/tests/function_test/e2e_test.py index 023508c..d50e9a6 100644 --- a/tests/function_test/e2e_test.py +++ b/tests/function_test/e2e_test.py @@ -37,7 +37,7 @@ def teardown_function(function): if os.path.exists('./test_result.json'): case_info = {} test_name = function.__name__ - case_info['case_info'] = test_name + '_' + os.environ.get('JOB_ID') + case_info['case_info'] = f"{test_name}_{os.environ.get('JOB_ID', 'local')}" update_jsonl_from_json('./test_result.json', '../total_result.jsonl', case_info) else: print('Warning! There is no test_result.json') diff --git a/tests/function_test/test_evaluator.py b/tests/function_test/test_evaluator.py index 155a3b2..36e40d9 100644 --- a/tests/function_test/test_evaluator.py +++ b/tests/function_test/test_evaluator.py @@ -49,8 +49,8 @@ task=TaskCfg( task_name='test_evaluation', task_settings={ - 'env_num': 2, - 'use_distributed': True, # Ray distributed framework + 'env_num': 1, + 'use_distributed': False, # Ray distributed framework 'proc_num': 4, }, scene=SceneCfg( @@ -70,7 +70,12 @@ 'filter_stairs': False, }, ), - eval_settings={'save_to_json': False, 'vis_output': False}, # save result to video under logs/ + eval_type='vln_distributed', + eval_settings={ + 'save_to_json': False, + 'vis_output': False, + 'use_agent_server': True, + }, # save result to video under logs/ )
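Several of the updated configs (h1_cma_cfg.py, h1_rdp_cfg.py, and the test_evaluator.py fixture above) now set use_agent_server=True, and the async config notes that the agent server has to be running before the evaluator starts. Below is a hedged sketch of doing that from Python with the start_server() helper added in this diff; the import path is an assumption based on the file location internnav/utils/comm_utils/server.py, and 8087 is simply the helper's default port.

from internnav.utils.comm_utils.server import start_server  # assumed module path, see note above

server_proc = start_server(host="localhost", port=8087)  # spawns the AgentServer in a daemon process
try:
    # ... launch the evaluator here, e.g. `python scripts/eval/eval.py --config <cfg>` ...
    pass
finally:
    server_proc.terminate()  # the daemon process would also exit when the parent does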