Skip to content

A question about how the clip model works #6

@2019211753

Description

@2019211753

In vlmrm, cosine similarity acts as a reward to guide the RL model. How does the CLIP model do this?
I did a toy experiment to test the ability of the text features of a CLIP model:

import os

import open_clip
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms.functional as Ftv
from PIL import Image
from tqdm.auto import tqdm
# Seed the RNG before the random starting image is drawn so runs are repeatable.
torch.manual_seed(2024)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Number of optimisation steps taken on the single image below.
epoch = 500

# The image pixels themselves are the only trainable parameters.
# NOTE(review): randn yields values well outside [0, 1], while the final save
# step clamps the result to [0, 1] -- confirm this initialisation is intended.
image = torch.randn((1, 3, 224, 224), device=device, requires_grad=True)
init_lr = 0.1
optimizer = torch.optim.Adam([image], lr=init_lr)
# Halve the learning rate every 100 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)

# Load CLIP in eval mode with gradients disabled on its weights: it only
# scores the image, it is never trained.
clip_guidance_model, _, processor = open_clip.create_model_and_transforms('ViT-bigG-14', pretrained='laion2b_s39b_b160k')
clip_guidance_model.eval()
clip_guidance_model.to(device)
clip_guidance_model.requires_grad_(False)
tokenizer = open_clip.get_tokenizer('ViT-bigG-14')

prompt = ['a dog']
# Use the tokenizer obtained for this specific model rather than the generic
# module-level `open_clip.tokenize`, so tokenisation always matches the model.
text = tokenizer(prompt).to(device)

# The text embedding is a fixed target: compute it once, outside autograd.
with torch.no_grad():
    text_features = clip_guidance_model.encode_text(text)

def total_variation_loss(img):
    """Return the summed absolute difference between neighbouring entries.

    Differences are taken between adjacent positions along axis 2 and along
    axis 3 of ``img`` (the caller in this script passes a 4-D image batch);
    the two sums of absolute differences are added into one scalar tensor.
    """
    # Neighbour-to-neighbour differences along axis 2 ...
    axis2_delta = img[:, :, :-1, :] - img[:, :, 1:, :]
    # ... and along axis 3.
    axis3_delta = img[:, :, :, :-1] - img[:, :, :, 1:]
    return axis2_delta.abs().sum() + axis3_delta.abs().sum()

def clip_loss(img, text_features, tv_weight=1e-6):
    """Return minus the mean cosine similarity between ``img`` and ``text_features``.

    ``img`` is resized, centre-cropped and channel-normalised, then embedded
    with the module-level ``clip_guidance_model``. Image and text embeddings
    are L2-normalised along dim 1, so their matrix product holds cosine
    similarities; minimising the returned value maximises that similarity.

    Args:
        img: image tensor to score (this script passes a (1, 3, 224, 224)
            tensor that requires grad).
        text_features: text embedding(s) produced by ``encode_text``.
        tv_weight: weight for a total-variation penalty. The penalty is
            disabled in this function, so the argument currently has no
            effect; it is kept so existing callers keep working.

    Returns:
        Scalar tensor holding the negated mean cosine similarity.
    """
    # Full-precision CLIP preprocessing statistics (the same values open_clip
    # publishes as OPENAI_DATASET_MEAN / OPENAI_DATASET_STD), so the input is
    # normalised the way the model's own preprocessing would do it instead of
    # with rounded approximations.
    mean = [0.48145466, 0.4578275, 0.40821073]
    std = [0.26862954, 0.26130258, 0.27577711]
    interpolation = transforms.InterpolationMode.BICUBIC
    size = 224

    resized = Ftv.resize(img, size, interpolation=interpolation)
    cropped = Ftv.center_crop(resized, size)
    # inplace=False: never overwrite the tensor that is being optimised.
    processed = Ftv.normalize(cropped, mean, std, inplace=False)
    image_features = clip_guidance_model.encode_image(processed)

    input_normed = F.normalize(image_features, dim=1)
    embed_normed = F.normalize(text_features, dim=1)

    # Pairwise cosine similarities between every image and every text embedding.
    cosine_sim = input_normed @ embed_normed.T

    # NOTE: total-variation regularisation (tv_weight * total_variation_loss(img))
    # is not applied here.
    return -cosine_sim.mean()

# Gradient-descent loop over the image pixels: one forward/backward pass,
# one optimiser update and one learning-rate scheduler tick per step.
log_every = 50
for step in tqdm(range(epoch)):
    loss = clip_loss(image, text_features)

    # Drop gradients left from the previous step before accumulating new ones.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()

    # Report on the very first step and then on every `log_every`-th step.
    is_first_step = step == 0
    at_log_interval = (step + 1) % log_every == 0
    if is_first_step or at_log_interval:
        print(f'timestep:{step}, loss:{loss.item()}')

# Convert the optimised tensor into an 8-bit image and write it to disk.
with torch.no_grad():
    # (1, 3, H, W) -> (H, W, 3): drop the batch axis and move channels last.
    final_image = image.clone().detach().cpu().squeeze().permute(1, 2, 0)
    # Values outside [0, 1] are clipped before scaling to the 0-255 range.
    final_image = torch.clamp(final_image, 0, 1)
    final_image = (final_image * 255).to(torch.uint8).numpy()
    final_image = Image.fromarray(final_image)
    # Make sure the output directory exists; save() does not create it and
    # would otherwise raise FileNotFoundError after the whole run has finished.
    os.makedirs('results', exist_ok=True)
    final_image.save(f'results/{prompt[0]}_CLIP.png')

I want to generate a picture of a dog; however, when the training is over,

  0%|          | 1/500 [00:00<04:21,  1.91it/s]timestep:0, loss:-0.2738375663757324
 10%|█         | 50/500 [00:08<01:23,  5.40it/s]timestep:49, loss:-0.8890969753265381
 20%|██        | 100/500 [00:16<01:14,  5.38it/s]timestep:99, loss:-0.9663792252540588
 30%|███       | 150/500 [00:25<01:05,  5.38it/s]timestep:149, loss:-0.9884988069534302
 40%|████      | 200/500 [00:33<00:56,  5.33it/s]timestep:199, loss:-0.993033230304718
 50%|█████     | 250/500 [00:42<00:46,  5.34it/s]timestep:249, loss:-0.995253324508667
 60%|██████    | 300/500 [00:50<00:37,  5.28it/s]timestep:299, loss:-0.9952336549758911
 70%|███████   | 350/500 [00:59<00:28,  5.26it/s]timestep:349, loss:-0.9962178468704224
 80%|████████  | 400/500 [01:07<00:19,  5.24it/s]timestep:399, loss:-0.9964497089385986
 90%|█████████ | 450/500 [01:16<00:09,  5.16it/s]timestep:449, loss:-0.9965552687644958
100%|██████████| 500/500 [01:25<00:00,  5.87it/s]
timestep:499, loss:-0.9966541528701782

the result is still noise:
dea7392d8b9434b731025a830eb6ac2
Could you please help me understand it?

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions