diff --git a/.gitignore b/.gitignore index c678a5e..2b71dbb 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] +.DS_Store \ No newline at end of file diff --git a/examples/gs3drecon.py b/examples/gs3drecon.py index a0e2993..2cf3114 100644 --- a/examples/gs3drecon.py +++ b/examples/gs3drecon.py @@ -1,7 +1,7 @@ +import open3d import torch import torch.nn as nn import torch.nn.functional as F -import open3d import numpy as np import math import os @@ -206,15 +206,24 @@ def forward(self, x): class Reconstruction3D: def __init__(self, dev): - self.cpuorgpu = "cpu" + self.device_type = "cpu" self.dm_zero_counter = 0 self.dm_zero = np.zeros((dev.imgw, dev.imgh)) pass - def load_nn(self, net_path, cpuorgpu): + def load_nn(self, net_path, device_type=None): + + # Automatically select device if not provided + if device_type is None: + if torch.cuda.is_available(): + device_type = 'cuda' + elif torch.backends.mps.is_available(): + device_type = 'mps' + else: + device_type = 'cpu' - self.cpuorgpu = cpuorgpu - device = torch.device(cpuorgpu) + self.device_type = device_type + device = torch.device(device_type) if not os.path.isfile(net_path): print('Error opening ', net_path, ' does not exist') @@ -223,11 +232,15 @@ def load_nn(self, net_path, cpuorgpu): net = RGB2NormNet().float().to(device) - if cpuorgpu=="cuda": + if device_type=="cuda": ### load weights on gpu # net.load_state_dict(torch.load(net_path)) checkpoint = torch.load(net_path, map_location=lambda storage, loc: storage.cuda(0)) net.load_state_dict(checkpoint['state_dict']) + elif device_type == "mps" : + ### load weights on mac m seriel + checkpoint = torch.load(net_path, map_location=lambda storage, loc: storage.mps()) + net.load_state_dict(checkpoint['state_dict']) else: ### load weights on cpu which were actually trained on gpu checkpoint = torch.load(net_path, map_location=lambda storage, loc: storage) @@ -281,7 +294,7 @@ def get_depthmap(self, frame, mask_markers, cm=None): # pxpos[:, 1] = pxpos[:, 1] / ((240 / imgw) * imgw) features = np.column_stack((rgb, pxpos)) - features = torch.from_numpy(features).float().to(self.cpuorgpu) + features = torch.from_numpy(features).float().to(self.device_type) with torch.no_grad(): self.net.eval() out = self.net(features) diff --git a/examples/gsdevice.py b/examples/gsdevice.py index 7bfee66..e0f869b 100644 --- a/examples/gsdevice.py +++ b/examples/gsdevice.py @@ -1,5 +1,6 @@ import cv2 import numpy as np +import platform import os import re @@ -17,6 +18,13 @@ def get_camera_id(camera_name): cam_num = None if os.name == 'nt': cam_num = find_cameras_windows(camera_name) + elif platform.system() == "Darwin": + import usb.core + devices = usb.core.find(find_all=True) + for idx, device in enumerate(devices): + if camera_name in device.product: + cam_num = idx + break else: for file in os.listdir("/sys/class/video4linux"): real_file = os.path.realpath("/sys/class/video4linux/" + file + "/name") @@ -28,7 +36,9 @@ def get_camera_id(camera_name): else: found = " " print("{} {} -> {}".format(found, file, name)) - + if cam_num is None: + print("ERROR! Can't Found Camera Device") + exit() return cam_num if os.name == 'nt': diff --git a/gelsight/gs3drecon.py b/gelsight/gs3drecon.py index a0e2993..2cf3114 100644 --- a/gelsight/gs3drecon.py +++ b/gelsight/gs3drecon.py @@ -1,7 +1,7 @@ +import open3d import torch import torch.nn as nn import torch.nn.functional as F -import open3d import numpy as np import math import os @@ -206,15 +206,24 @@ def forward(self, x): class Reconstruction3D: def __init__(self, dev): - self.cpuorgpu = "cpu" + self.device_type = "cpu" self.dm_zero_counter = 0 self.dm_zero = np.zeros((dev.imgw, dev.imgh)) pass - def load_nn(self, net_path, cpuorgpu): + def load_nn(self, net_path, device_type=None): + + # Automatically select device if not provided + if device_type is None: + if torch.cuda.is_available(): + device_type = 'cuda' + elif torch.backends.mps.is_available(): + device_type = 'mps' + else: + device_type = 'cpu' - self.cpuorgpu = cpuorgpu - device = torch.device(cpuorgpu) + self.device_type = device_type + device = torch.device(device_type) if not os.path.isfile(net_path): print('Error opening ', net_path, ' does not exist') @@ -223,11 +232,15 @@ def load_nn(self, net_path, cpuorgpu): net = RGB2NormNet().float().to(device) - if cpuorgpu=="cuda": + if device_type=="cuda": ### load weights on gpu # net.load_state_dict(torch.load(net_path)) checkpoint = torch.load(net_path, map_location=lambda storage, loc: storage.cuda(0)) net.load_state_dict(checkpoint['state_dict']) + elif device_type == "mps" : + ### load weights on mac m seriel + checkpoint = torch.load(net_path, map_location=lambda storage, loc: storage.mps()) + net.load_state_dict(checkpoint['state_dict']) else: ### load weights on cpu which were actually trained on gpu checkpoint = torch.load(net_path, map_location=lambda storage, loc: storage) @@ -281,7 +294,7 @@ def get_depthmap(self, frame, mask_markers, cm=None): # pxpos[:, 1] = pxpos[:, 1] / ((240 / imgw) * imgw) features = np.column_stack((rgb, pxpos)) - features = torch.from_numpy(features).float().to(self.cpuorgpu) + features = torch.from_numpy(features).float().to(self.device_type) with torch.no_grad(): self.net.eval() out = self.net(features) diff --git a/gelsight/gsdevice.py b/gelsight/gsdevice.py index 7bfee66..e0f869b 100644 --- a/gelsight/gsdevice.py +++ b/gelsight/gsdevice.py @@ -1,5 +1,6 @@ import cv2 import numpy as np +import platform import os import re @@ -17,6 +18,13 @@ def get_camera_id(camera_name): cam_num = None if os.name == 'nt': cam_num = find_cameras_windows(camera_name) + elif platform.system() == "Darwin": + import usb.core + devices = usb.core.find(find_all=True) + for idx, device in enumerate(devices): + if camera_name in device.product: + cam_num = idx + break else: for file in os.listdir("/sys/class/video4linux"): real_file = os.path.realpath("/sys/class/video4linux/" + file + "/name") @@ -28,7 +36,9 @@ def get_camera_id(camera_name): else: found = " " print("{} {} -> {}".format(found, file, name)) - + if cam_num is None: + print("ERROR! Can't Found Camera Device") + exit() return cam_num if os.name == 'nt':