diff --git a/README.md b/README.md index e364ce4..695deb9 100644 --- a/README.md +++ b/README.md @@ -77,6 +77,9 @@ pip install -r requirements.txt # Run the Track-Anything gradio demo. python app.py --device cuda:0 + +# If your platform AppleM2 use +python app.py --device mps # python app.py --device cuda:0 --sam_model_type vit_b # for lower memory usage ``` diff --git a/app.py b/app.py index 870fae0..a6aef87 100644 --- a/app.py +++ b/app.py @@ -114,6 +114,7 @@ def get_frames_from_video(video_input, video_state): video_info = "Video Name: {}, FPS: {}, Total Frames: {}, Image Size:{}".format(video_state["video_name"], video_state["fps"], len(frames), image_size) model.samcontroler.sam_controler.reset_image() model.samcontroler.sam_controler.set_image(video_state["origin_images"][0]) + print(video_info) return video_state, video_info, video_state["origin_images"][0], gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=True, maximum=len(frames), value=len(frames)), \ gr.update(visible=True),\ gr.update(visible=True), gr.update(visible=True), \ @@ -336,18 +337,24 @@ def generate_video_from_frames(frames, output_path, fps=30): output_path (str): The path to save the generated video. fps (int, optional): The frame rate of the output video. Defaults to 30. """ - # height, width, layers = frames[0].shape - # fourcc = cv2.VideoWriter_fourcc(*"mp4v") - # video = cv2.VideoWriter(output_path, fourcc, fps, (width, height)) - # print(output_path) - # for frame in frames: - # video.write(frame) - + height, width, layers = frames[0].shape + print(f"Video width: {width}, height: {height}") + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + video = cv2.VideoWriter(output_path, fourcc, fps, (width, height)) + print(output_path) + for frame in frames: + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + video.write(frame) + # zhaifang add + ''' + height, width, layers = frames[0].shape + print(f"Video width: {width}, height: {height}") # video.release() frames = torch.from_numpy(np.asarray(frames)) if not os.path.exists(os.path.dirname(output_path)): os.makedirs(os.path.dirname(output_path)) torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264") + ''' return output_path @@ -377,8 +384,8 @@ def generate_video_from_frames(frames, output_path, fps=30): SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint) xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint) e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint) -args.port = 12212 -args.device = "cuda:3" +args.port = 7860 +args.device = "mps" # args.mask_save = True # initialize sam, xmem, e2fgvi models @@ -428,8 +435,8 @@ def generate_video_from_frames(frames, output_path, fps=30): # for user video input with gr.Column(): - with gr.Row(scale=0.4): - video_input = gr.Video(autosize=True) + with gr.Row():#scale=0.4 + video_input = gr.Video()#autosize=True with gr.Column(): video_info = gr.Textbox(label="Video Info") resize_info = gr.Textbox(value="If you want to use the inpaint function, it is best to git clone the repo and use a machine with more VRAM locally. \ @@ -454,16 +461,17 @@ def generate_video_from_frames(frames, output_path, fps=30): interactive=True, visible=False) remove_mask_button = gr.Button(value="Remove mask", interactive=True, visible=False) - clear_button_click = gr.Button(value="Clear clicks", interactive=True, visible=False).style(height=160) + #clear_button_click = gr.Button(value="Clear clicks", interactive=True, visible=False).style(height=160) + clear_button_click = gr.Button(value="Clear clicks", interactive=True, visible=False) Add_mask_button = gr.Button(value="Add mask", interactive=True, visible=False) - template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame", visible=False).style(height=360) + template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame", visible=False) image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track start frame", visible=False) track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frame", visible=False) with gr.Column(): run_status = gr.HighlightedText(value=[("Text","Error"),("to be","Label 2"),("highlighted","Label 3")], visible=False) mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask selection", info=".", visible=False) - video_output = gr.Video(autosize=True, visible=False).style(height=360) + video_output = gr.Video(autoplay=True, visible=False) with gr.Row(): tracking_video_predict_button = gr.Button(value="Tracking", visible=False) inpaint_video_predict_button = gr.Button(value="Inpainting", visible=False) @@ -583,13 +591,13 @@ def generate_video_from_frames(frames, output_path, fps=30): clear_button_click.click( fn = clear_click, inputs = [video_state, click_state,], - outputs = [template_frame,click_state, run_status], + outputs = [template_frame,click_state, run_status] ) # set example gr.Markdown("## Examples") gr.Examples( examples=[os.path.join(os.path.dirname(__file__), "./test_sample/", test_sample) for test_sample in ["test-sample8.mp4","test-sample4.mp4", \ - "test-sample2.mp4","test-sample13.mp4"]], + "test-sample2.mp4","test-sample13.mp4", "RGB_video.mp4"]], fn=run_example, inputs=[ video_input @@ -597,6 +605,5 @@ def generate_video_from_frames(frames, output_path, fps=30): outputs=[video_input], # cache_examples=True, ) -iface.queue(concurrency_count=1) -iface.launch(debug=True, enable_queue=True, server_port=args.port, server_name="0.0.0.0") -# iface.launch(debug=True, enable_queue=True) \ No newline at end of file + +iface.launch(debug=True, server_port=args.port, server_name="127.0.0.1",max_threads=1,share=True) diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..8078c88 --- /dev/null +++ b/setup.py @@ -0,0 +1,20 @@ +from setuptools import setup, find_packages + +setup( + name="tracker", + version="0.2.1", + packages=find_packages(), + install_requires=[], + author="zhaifang", + author_email="zhaifang@tsinghua.edu.cn", + description="xmem tracking for 3 sensor short long-term memory", + long_description=open('README.md').read(), + long_description_content_type='text/markdown', + url="git@github.com:bingxinhu/Track-Anything.git", + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + ], + python_requires='>=3.11', +) diff --git a/tools/base_segmenter.py b/tools/base_segmenter.py index 2b975bb..7ac4615 100644 --- a/tools/base_segmenter.py +++ b/tools/base_segmenter.py @@ -11,7 +11,7 @@ class BaseSegmenter: - def __init__(self, SAM_checkpoint, model_type, device='cuda:0'): + def __init__(self, SAM_checkpoint, model_type, device='mps'): """ device: model device SAM_checkpoint: path of SAM checkpoint @@ -85,7 +85,7 @@ def predict(self, prompts, mode, multimask=True): # initialise BaseSegmenter SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth' model_type = 'vit_h' - device = "cuda:4" + device = "mps" base_segmenter = BaseSegmenter(SAM_checkpoint=SAM_checkpoint, model_type=model_type, device=device) # image embedding (once embedded, multiple prompts can be applied) diff --git a/track_anything.py b/track_anything.py index 5275252..34c80c9 100644 --- a/track_anything.py +++ b/track_anything.py @@ -61,11 +61,11 @@ def generator(self, images: list, template_mask:np.ndarray): def parse_augment(): parser = argparse.ArgumentParser() - parser.add_argument('--device', type=str, default="cuda:0") + parser.add_argument('--device', type=str, default="mps") parser.add_argument('--sam_model_type', type=str, default="vit_h") parser.add_argument('--port', type=int, default=6080, help="only useful when running gradio applications") parser.add_argument('--debug', action="store_true") - parser.add_argument('--mask_save', default=False) + parser.add_argument('--mask_save', default=True) args = parser.parse_args() if args.debug: @@ -78,7 +78,7 @@ def parse_augment(): logits = None painted_images = None images = [] - image = np.array(PIL.Image.open('/hhd3/gaoshang/truck.jpg')) + image = np.array(PIL.Image.open('./img/dogs.jpg')) args = parse_augment() # images.append(np.ones((20,20,3)).astype('uint8')) # images.append(np.ones((20,20,3)).astype('uint8')) @@ -87,10 +87,11 @@ def parse_augment(): mask = np.zeros_like(image)[:,:,0] mask[0,0]= 1 - trackany = TrackingAnything('/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth','/ssd1/gaomingqi/checkpoints/XMem-s012.pth', args) + trackany = TrackingAnything('./checkpoints/sam_vit_h_4b8939.pth','./checkpoints/XMem-s012.pth', './checkpoints/E2FGVI-HQ-CVPR22.pth', args) masks, logits ,painted_images= trackany.generator(images, mask) + - \ No newline at end of file + diff --git a/tracker/__init__.py b/tracker/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tracker/__pycache__/__init__.cpython-311.pyc b/tracker/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..0e357cf Binary files /dev/null and b/tracker/__pycache__/__init__.cpython-311.pyc differ diff --git a/tracker/__pycache__/base_tracker.cpython-311.pyc b/tracker/__pycache__/base_tracker.cpython-311.pyc new file mode 100644 index 0000000..bf32544 Binary files /dev/null and b/tracker/__pycache__/base_tracker.cpython-311.pyc differ diff --git a/tracker/base_tracker.py b/tracker/base_tracker.py index 8c4ee02..4fb08a8 100644 --- a/tracker/base_tracker.py +++ b/tracker/base_tracker.py @@ -7,8 +7,9 @@ import torch import yaml import torch.nn.functional as F +from tracker.inference.inference_core import InferenceCore from tracker.model.network import XMem -from inference.inference_core import InferenceCore + from tracker.util.mask_mapper import MaskMapper from torchvision import transforms from tracker.util.range_transform import im_normalization @@ -25,6 +26,11 @@ def __init__(self, xmem_checkpoint, device, sam_model=None, model_type=None) -> device: model device xmem_checkpoint: checkpoint of XMem model """ + if device is None: + if torch.backends.mps.is_available(): + device = torch.device("mps") + else: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # load configurations with open("tracker/config/config.yaml", 'r') as stream: config = yaml.safe_load(stream) @@ -103,7 +109,7 @@ def track(self, frame, first_frame_annotation=None): # print(f'max memory allocated: {torch.cuda.max_memory_allocated()/(2**20)} MB') - return final_mask, final_mask, painted_image + return final_mask, probs, painted_image @torch.no_grad() def sam_refinement(self, frame, logits, ti): @@ -126,8 +132,11 @@ def sam_refinement(self, frame, logits, ti): def clear_memory(self): self.tracker.clear_memory() self.mapper.clear_labels() - torch.cuda.empty_cache() - + if self.device == "cuda": + torch.cuda.empty_cache() + if self.device == "mps": + torch.mps.empty_cache() + ## how to use: ## 1/3) prepare device and xmem_checkpoint @@ -155,7 +164,7 @@ def clear_memory(self): # how to use # ------------------------------------------------------------------------------------ # 1/4: set checkpoint and device - device = 'cuda:2' + device = 'mps' XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth' # SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth' # model_type = 'vit_h' @@ -179,7 +188,10 @@ def clear_memory(self): # ---------------------------------------------- # end # ---------------------------------------------- - print(f'max memory allocated: {torch.cuda.max_memory_allocated()/(2**20)} MB') + if device == "cuda": + print(f'max memory allocated: {torch.cuda.max_memory_allocated()/(2**20)} MB') + if device == "mps": + print(f'max memory allocated: {torch.mps.driver_allocated_memory()/(2**20)} MB') # set saving path save_path = '/ssd1/gaomingqi/results/TAM/blackswan' if not os.path.exists(save_path): diff --git a/tracker/inference/__pycache__/__init__.cpython-311.pyc b/tracker/inference/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..34ecfdf Binary files /dev/null and b/tracker/inference/__pycache__/__init__.cpython-311.pyc differ diff --git a/tracker/inference/__pycache__/inference_core.cpython-311.pyc b/tracker/inference/__pycache__/inference_core.cpython-311.pyc new file mode 100644 index 0000000..9da6b84 Binary files /dev/null and b/tracker/inference/__pycache__/inference_core.cpython-311.pyc differ diff --git a/tracker/inference/__pycache__/kv_memory_store.cpython-311.pyc b/tracker/inference/__pycache__/kv_memory_store.cpython-311.pyc new file mode 100644 index 0000000..26d6ee8 Binary files /dev/null and b/tracker/inference/__pycache__/kv_memory_store.cpython-311.pyc differ diff --git a/tracker/inference/__pycache__/memory_manager.cpython-311.pyc b/tracker/inference/__pycache__/memory_manager.cpython-311.pyc new file mode 100644 index 0000000..988b65f Binary files /dev/null and b/tracker/inference/__pycache__/memory_manager.cpython-311.pyc differ diff --git a/tracker/inference/inference_core.py b/tracker/inference/inference_core.py index e77f080..0b88e64 100644 --- a/tracker/inference/inference_core.py +++ b/tracker/inference/inference_core.py @@ -1,7 +1,8 @@ -from inference.memory_manager import MemoryManager -from model.network import XMem -from model.aggregate import aggregate +from tracker.inference.memory_manager import MemoryManager + +from tracker.model.aggregate import aggregate +from tracker.model.network import XMem from tracker.util.tensor_util import pad_divide_by, unpad diff --git a/tracker/inference/memory_manager.py b/tracker/inference/memory_manager.py index d47d96e..32bc107 100644 --- a/tracker/inference/memory_manager.py +++ b/tracker/inference/memory_manager.py @@ -1,8 +1,8 @@ import torch import warnings -from inference.kv_memory_store import KeyValueMemoryStore -from model.memory_util import * +from tracker.model.memory_util import * +from tracker.inference.kv_memory_store import KeyValueMemoryStore class MemoryManager: diff --git a/tracker/model/__pycache__/__init__.cpython-311.pyc b/tracker/model/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..6e0d521 Binary files /dev/null and b/tracker/model/__pycache__/__init__.cpython-311.pyc differ diff --git a/tracker/model/__pycache__/aggregate.cpython-311.pyc b/tracker/model/__pycache__/aggregate.cpython-311.pyc new file mode 100644 index 0000000..798d6ef Binary files /dev/null and b/tracker/model/__pycache__/aggregate.cpython-311.pyc differ diff --git a/tracker/model/__pycache__/cbam.cpython-311.pyc b/tracker/model/__pycache__/cbam.cpython-311.pyc new file mode 100644 index 0000000..dd7c324 Binary files /dev/null and b/tracker/model/__pycache__/cbam.cpython-311.pyc differ diff --git a/tracker/model/__pycache__/group_modules.cpython-311.pyc b/tracker/model/__pycache__/group_modules.cpython-311.pyc new file mode 100644 index 0000000..646863f Binary files /dev/null and b/tracker/model/__pycache__/group_modules.cpython-311.pyc differ diff --git a/tracker/model/__pycache__/memory_util.cpython-311.pyc b/tracker/model/__pycache__/memory_util.cpython-311.pyc new file mode 100644 index 0000000..747c5b5 Binary files /dev/null and b/tracker/model/__pycache__/memory_util.cpython-311.pyc differ diff --git a/tracker/model/__pycache__/modules.cpython-311.pyc b/tracker/model/__pycache__/modules.cpython-311.pyc new file mode 100644 index 0000000..8c555c2 Binary files /dev/null and b/tracker/model/__pycache__/modules.cpython-311.pyc differ diff --git a/tracker/model/__pycache__/network.cpython-311.pyc b/tracker/model/__pycache__/network.cpython-311.pyc new file mode 100644 index 0000000..152e403 Binary files /dev/null and b/tracker/model/__pycache__/network.cpython-311.pyc differ diff --git a/tracker/model/__pycache__/resnet.cpython-311.pyc b/tracker/model/__pycache__/resnet.cpython-311.pyc new file mode 100644 index 0000000..c7e2f29 Binary files /dev/null and b/tracker/model/__pycache__/resnet.cpython-311.pyc differ diff --git a/tracker/model/modules.py b/tracker/model/modules.py index 9920799..8161c5a 100644 --- a/tracker/model/modules.py +++ b/tracker/model/modules.py @@ -14,9 +14,11 @@ import torch.nn as nn import torch.nn.functional as F -from model.group_modules import * -from model import resnet -from model.cbam import CBAM +from tracker.model import resnet +from tracker.model.cbam import CBAM +from tracker.model.group_modules import GConv2D, GroupResBlock, MainToGroupDistributor, downsample_groups, upsample_groups + + class FeatureFusionBlock(nn.Module): diff --git a/tracker/model/network.py b/tracker/model/network.py index 70b7e92..be1b4ec 100644 --- a/tracker/model/network.py +++ b/tracker/model/network.py @@ -8,10 +8,11 @@ import torch import torch.nn as nn +from tracker.model.aggregate import aggregate +from tracker.model.memory_util import get_affinity, readout +from tracker.model.modules import Decoder, KeyEncoder, KeyProjection, ValueEncoder + -from model.aggregate import aggregate -from model.modules import * -from model.memory_util import * class XMem(nn.Module): diff --git a/tracker/util/__pycache__/__init__.cpython-311.pyc b/tracker/util/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..c5ae92b Binary files /dev/null and b/tracker/util/__pycache__/__init__.cpython-311.pyc differ diff --git a/tracker/util/__pycache__/mask_mapper.cpython-311.pyc b/tracker/util/__pycache__/mask_mapper.cpython-311.pyc new file mode 100644 index 0000000..dae279b Binary files /dev/null and b/tracker/util/__pycache__/mask_mapper.cpython-311.pyc differ diff --git a/tracker/util/__pycache__/range_transform.cpython-311.pyc b/tracker/util/__pycache__/range_transform.cpython-311.pyc new file mode 100644 index 0000000..0a1fa46 Binary files /dev/null and b/tracker/util/__pycache__/range_transform.cpython-311.pyc differ diff --git a/tracker/util/__pycache__/tensor_util.cpython-311.pyc b/tracker/util/__pycache__/tensor_util.cpython-311.pyc new file mode 100644 index 0000000..20f0a3b Binary files /dev/null and b/tracker/util/__pycache__/tensor_util.cpython-311.pyc differ