32 lines
904 B
Python
32 lines
904 B
Python
import torch
|
|
|
|
class CoordStage(object):
    """Fake VQ-model stage that quantizes a single-channel map in [0, 1].

    Mimics a VQ model's ``eval`` / ``encode`` / ``decode`` interface. It has no
    learned codebook. Instead, ``encode`` downsamples the input and rounds
    ``n_embed * value`` to an integer level. ``decode`` divides by ``n_embed``
    and upsamples back.

    Args:
        n_embed: Scale applied before rounding. Encoded levels lie in
            ``[0, n_embed]`` (``n_embed + 1`` levels for integer ``n_embed``).
        down_factor: Spatial factor used to downsample in ``encode`` and to
            upsample in ``decode``.
    """

    def __init__(self, n_embed, down_factor):
        self.n_embed = n_embed
        self.down_factor = down_factor

    def eval(self):
        """Return ``self`` unchanged; there is no train/eval state to switch."""
        return self

    def encode(self, c):
        """Downsample and quantize ``c`` (fake vqmodel interface).

        Args:
            c: Tensor of shape ``(batch, 1, H, W)`` with values in [0, 1].

        Returns:
            Tuple ``(c_quant, None, (None, None, c_ind))``:
            - ``c_quant`` is the float tensor of rounded levels in
              ``[0, n_embed]``.
            - ``c_ind`` is the same tensor cast to ``torch.long``.
            - The ``None`` entries fill slots of a VQ model's return value
              that this stage does not compute.

        Raises:
            ValueError: If ``c`` has values outside [0, 1] (or NaN), is not
                4-D, or does not have exactly one channel.
        """
        # Explicit checks rather than ``assert`` so validation survives
        # ``python -O``. The negated form also rejects NaN, which compares
        # False against any bound.
        if not (0.0 <= c.min() and c.max() <= 1.0):
            raise ValueError("CoordStage.encode expects values in [0, 1]")
        if c.dim() != 4:
            raise ValueError(
                "CoordStage.encode expects a 4-D tensor (batch, 1, H, W), "
                "got shape {}".format(tuple(c.shape)))
        if c.shape[1] != 1:
            raise ValueError(
                "CoordStage.encode expects exactly 1 channel, got {}".format(
                    c.shape[1]))

        # "area" mode averages input regions. These are exact
        # down_factor x down_factor patch means when H and W divide evenly.
        c = torch.nn.functional.interpolate(
            c, scale_factor=1 / self.down_factor, mode="area")
        # Defensive clamp: averages of values in [0, 1] can leave the range
        # only through floating-point error.
        c = c.clamp(0.0, 1.0)
        c_quant = (self.n_embed * c).round()
        c_ind = c_quant.to(dtype=torch.long)

        info = (None, None, c_ind)
        return c_quant, None, info

    def decode(self, c):
        """Map quantized levels back to [0, 1] scale and upsample.

        Args:
            c: Tensor of levels, e.g. ``c_quant`` returned by ``encode``.
                Assumed to be 4-D ``(batch, 1, h, w)``.

        Returns:
            ``c / n_embed`` upsampled by ``down_factor`` with nearest-neighbour
            interpolation.
        """
        c = c / self.n_embed
        return torch.nn.functional.interpolate(
            c, scale_factor=self.down_factor, mode="nearest")
|