first commit
This commit is contained in:
127
nerf/clip.py
Normal file
127
nerf/clip.py
Normal file
@@ -0,0 +1,127 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import torchvision.transforms as T
|
||||
import torchvision.transforms.functional as TF
|
||||
|
||||
# import clip
|
||||
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTokenizer, CLIPProcessor
|
||||
from torchvision import transforms
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def spherical_dist_loss(x, y):
    """Squared spherical (great-circle) distance between two embedding batches.

    Both inputs are L2-normalized along the last dim, so they lie on the unit
    sphere; the chord length between them is converted to an arc angle via
    arcsin, then squared and scaled by 2. Returns one distance per row.
    """
    x_unit = F.normalize(x, dim=-1)
    y_unit = F.normalize(y, dim=-1)
    chord = (x_unit - y_unit).norm(dim=-1)
    return 2.0 * (chord / 2.0).arcsin() ** 2
|
||||
|
||||
|
||||
class CLIP(nn.Module):
    """Text/image embedding wrapper around a HuggingFace CLIP checkpoint.

    Provides normalized text and image embeddings plus CLIP-guidance losses
    (negative cosine similarity and spherical distance) for use as a guidance
    model during NeRF training.
    """

    # CLIP image preprocessing statistics (values taken from the original
    # commented-out aug pipeline in this file).
    IMAGE_MEAN = (0.48145466, 0.4578275, 0.40821073)
    IMAGE_STD = (0.26862954, 0.26130258, 0.27577711)

    def __init__(self, device,
                 # clip_name = 'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'
                 clip_name='openai/clip-vit-large-patch14',
                 ):
        super().__init__()

        self.device = device
        self.clip_name = clip_name

        # Fix: respect the `device` argument instead of hard-coding .cuda().
        self.clip_model = CLIPModel.from_pretrained(clip_name).to(device)
        # Fix: load tokenizer/processor for the SAME checkpoint as the model
        # (they were previously hard-coded and would silently mismatch a
        # non-default clip_name).
        self.tokenizer = CLIPTokenizer.from_pretrained(clip_name)
        self.processor = CLIPProcessor.from_pretrained(clip_name)

        # Fix: these were commented out, but get_img_embeds/train_step use
        # self.aug and text_loss/img_loss use self.resize/self.normalize —
        # without them every image-side method raised AttributeError.
        self.resize = transforms.Resize((224, 224))
        self.normalize = transforms.Normalize(self.IMAGE_MEAN, self.IMAGE_STD)
        self.aug = T.Compose([self.resize, self.normalize])

    def get_text_embeds(self, prompt, neg_prompt=None, dir=None):
        """Return L2-normalized CLIP text embeddings for `prompt`.

        `neg_prompt` and `dir` are accepted for API compatibility with other
        guidance backends in this project but are unused here.
        """
        clip_text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        ).input_ids.to(self.device)  # fix: follow self.device, not .cuda()
        text_z = self.clip_model.get_text_features(clip_text_input)
        text_z = text_z / text_z.norm(dim=-1, keepdim=True)
        return text_z

    def set_epoch(self, epoch):
        """No-op; kept so the training loop can call it on any guidance model."""
        pass

    def get_img_embeds(self, img):
        """Return L2-normalized CLIP image embeddings for a batch of images.

        Assumes `img` is a float tensor of CHW images in [0, 1] — TODO confirm
        against the caller's rendering pipeline.
        """
        img = self.aug(img)
        image_z = self.clip_model.get_image_features(img)
        image_z = image_z / image_z.norm(dim=-1, keepdim=True)  # normalize features
        return image_z

    def train_step(self, text_z, pred_rgb):
        """Negative cosine-similarity loss between rendered image and text."""
        pred_rgb = self.aug(pred_rgb)

        # Fix: transformers.CLIPModel has no `encode_image` (that is the
        # openai-clip API); use get_image_features like every other method.
        image_z = self.clip_model.get_image_features(pred_rgb)
        image_z = image_z / image_z.norm(dim=-1, keepdim=True)  # normalize features

        loss = - (image_z * text_z).sum(-1).mean()
        return loss

    def text_loss(self, text_z, pred_rgb):
        """Spherical-distance loss between rendered image and text embedding."""
        pred_rgb = self.resize(pred_rgb)
        pred_rgb = self.normalize(pred_rgb)

        image_z = self.clip_model.get_image_features(pred_rgb)
        image_z = image_z / image_z.norm(dim=-1, keepdim=True)  # normalize features

        loss = spherical_dist_loss(image_z, text_z)
        return loss

    def img_loss(self, img_ref_z, pred_rgb):
        """Spherical-distance loss between rendered image and a reference image embedding."""
        pred_rgb = self.resize(pred_rgb)
        pred_rgb = self.normalize(pred_rgb)

        image_z = self.clip_model.get_image_features(pred_rgb)
        image_z = image_z / image_z.norm(dim=-1, keepdim=True)  # normalize features

        loss = spherical_dist_loss(image_z, img_ref_z)
        return loss
|
||||
Reference in New Issue
Block a user