Files
Magic123/nerf/clip.py
Guocheng Qian 13e18567fa first commit
2023-08-02 19:51:43 -07:00

128 lines
4.1 KiB
Python

import torch
import torch.nn as nn
import torchvision.transforms as T
import torchvision.transforms.functional as TF
# import clip
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTokenizer, CLIPProcessor
from torchvision import transforms
import torch.nn.functional as F
def spherical_dist_loss(x, y):
    """Squared great-circle distance between feature vectors.

    Each input is L2-normalized along its last dimension; the half-chord
    length ``||x - y|| / 2`` is mapped through ``arcsin`` to half the
    angle between the unit vectors, and the returned loss is
    ``2 * (angle / 2) ** 2`` per pair (shape: inputs minus last dim).
    """
    x_unit = F.normalize(x, dim=-1)
    y_unit = F.normalize(y, dim=-1)
    half_chord = (x_unit - y_unit).norm(dim=-1) / 2
    return 2.0 * half_chord.arcsin() ** 2
class CLIP(nn.Module):
    """Wrapper around a Hugging Face CLIP model providing text/image
    embeddings and CLIP-guidance losses for rendered images.

    Fixes over the original:
      * `clip_name` is now used for the tokenizer and processor as well,
        so all three components come from the same checkpoint (previously
        the tokenizer/processor hard-coded the openai/ViT-L-14 checkpoint
        and the parameter was silently discarded via `clip_name = clip_name`).
      * the model and token ids are moved to `device` instead of `.cuda()`.
      * `self.aug` / `self.resize` / `self.normalize` are defined (they were
        commented out, so get_img_embeds/train_step/text_loss/img_loss
        raised AttributeError at runtime).
      * `train_step` calls `get_image_features` — transformers' CLIPModel
        has no `encode_image` (that is the openai/clip package API).
    """

    # Standard CLIP preprocessing statistics (same values the original
    # commented-out `aug` used).
    _IMAGE_MEAN = (0.48145466, 0.4578275, 0.40821073)
    _IMAGE_STD = (0.26862954, 0.26130258, 0.27577711)

    def __init__(self, device,
                 # clip_name = 'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'
                 clip_name='openai/clip-vit-large-patch14'):
        super().__init__()
        self.device = device
        self.clip_name = clip_name

        self.clip_model = CLIPModel.from_pretrained(clip_name).to(device)
        self.tokenizer = CLIPTokenizer.from_pretrained(clip_name)
        self.processor = CLIPProcessor.from_pretrained(clip_name)

        # Preprocessing used by the embedding/loss methods below.
        self.normalize = transforms.Normalize(mean=self._IMAGE_MEAN,
                                              std=self._IMAGE_STD)
        self.resize = transforms.Resize(224)
        self.aug = T.Compose([
            T.Resize((224, 224)),
            T.Normalize(self._IMAGE_MEAN, self._IMAGE_STD),
        ])

    def get_text_embeds(self, prompt, neg_prompt=None, dir=None):
        """Tokenize `prompt` and return L2-normalized CLIP text features.

        `neg_prompt` and `dir` are accepted for trainer-interface
        compatibility but are unused.
        """
        clip_text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        ).input_ids.to(self.device)
        text_z = self.clip_model.get_text_features(clip_text_input)
        # Normalize so downstream dot products are cosine similarities.
        text_z = text_z / text_z.norm(dim=-1, keepdim=True)
        return text_z

    def set_epoch(self, epoch):
        # No epoch-dependent state; kept for trainer-interface compatibility.
        pass

    def get_img_embeds(self, img):
        """Return L2-normalized CLIP image features.

        `img` is assumed to be a (B, 3, H, W) tensor in [0, 1]
        — TODO confirm against callers.
        """
        img = self.aug(img)
        image_z = self.clip_model.get_image_features(img)
        image_z = image_z / image_z.norm(dim=-1, keepdim=True)
        return image_z

    def train_step(self, text_z, pred_rgb):
        """Negative cosine similarity between rendered image and text features."""
        pred_rgb = self.aug(pred_rgb)
        # transformers' CLIPModel exposes get_image_features; `encode_image`
        # (the original call) only exists in the openai/clip package.
        image_z = self.clip_model.get_image_features(pred_rgb)
        image_z = image_z / image_z.norm(dim=-1, keepdim=True)
        loss = -(image_z * text_z).sum(-1).mean()
        return loss

    def text_loss(self, text_z, pred_rgb):
        """Spherical-distance loss between image features and text features."""
        pred_rgb = self.normalize(self.resize(pred_rgb))
        image_z = self.clip_model.get_image_features(pred_rgb)
        image_z = image_z / image_z.norm(dim=-1, keepdim=True)
        loss = spherical_dist_loss(image_z, text_z)
        return loss

    def img_loss(self, img_ref_z, pred_rgb):
        """Spherical-distance loss between image features and a reference embedding."""
        pred_rgb = self.normalize(self.resize(pred_rgb))
        image_z = self.clip_model.get_image_features(pred_rgb)
        image_z = image_z / image_z.norm(dim=-1, keepdim=True)
        loss = spherical_dist_loss(image_z, img_ref_z)
        return loss