first commit
This commit is contained in:
248
taming/modules/transformer/permuter.py
Normal file
248
taming/modules/transformer/permuter.py
Normal file
@@ -0,0 +1,248 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
|
||||
|
||||
class AbstractPermuter(nn.Module):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__()
|
||||
def forward(self, x, reverse=False):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Identity(AbstractPermuter):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x, reverse=False):
|
||||
return x
|
||||
|
||||
|
||||
class Subsample(AbstractPermuter):
|
||||
def __init__(self, H, W):
|
||||
super().__init__()
|
||||
C = 1
|
||||
indices = np.arange(H*W).reshape(C,H,W)
|
||||
while min(H, W) > 1:
|
||||
indices = indices.reshape(C,H//2,2,W//2,2)
|
||||
indices = indices.transpose(0,2,4,1,3)
|
||||
indices = indices.reshape(C*4,H//2, W//2)
|
||||
H = H//2
|
||||
W = W//2
|
||||
C = C*4
|
||||
assert H == W == 1
|
||||
idx = torch.tensor(indices.ravel())
|
||||
self.register_buffer('forward_shuffle_idx',
|
||||
nn.Parameter(idx, requires_grad=False))
|
||||
self.register_buffer('backward_shuffle_idx',
|
||||
nn.Parameter(torch.argsort(idx), requires_grad=False))
|
||||
|
||||
def forward(self, x, reverse=False):
|
||||
if not reverse:
|
||||
return x[:, self.forward_shuffle_idx]
|
||||
else:
|
||||
return x[:, self.backward_shuffle_idx]
|
||||
|
||||
|
||||
def mortonify(i, j):
|
||||
"""(i,j) index to linear morton code"""
|
||||
i = np.uint64(i)
|
||||
j = np.uint64(j)
|
||||
|
||||
z = np.uint(0)
|
||||
|
||||
for pos in range(32):
|
||||
z = (z |
|
||||
((j & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos)) |
|
||||
((i & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos+1))
|
||||
)
|
||||
return z
|
||||
|
||||
|
||||
class ZCurve(AbstractPermuter):
|
||||
def __init__(self, H, W):
|
||||
super().__init__()
|
||||
reverseidx = [np.int64(mortonify(i,j)) for i in range(H) for j in range(W)]
|
||||
idx = np.argsort(reverseidx)
|
||||
idx = torch.tensor(idx)
|
||||
reverseidx = torch.tensor(reverseidx)
|
||||
self.register_buffer('forward_shuffle_idx',
|
||||
idx)
|
||||
self.register_buffer('backward_shuffle_idx',
|
||||
reverseidx)
|
||||
|
||||
def forward(self, x, reverse=False):
|
||||
if not reverse:
|
||||
return x[:, self.forward_shuffle_idx]
|
||||
else:
|
||||
return x[:, self.backward_shuffle_idx]
|
||||
|
||||
|
||||
class SpiralOut(AbstractPermuter):
|
||||
def __init__(self, H, W):
|
||||
super().__init__()
|
||||
assert H == W
|
||||
size = W
|
||||
indices = np.arange(size*size).reshape(size,size)
|
||||
|
||||
i0 = size//2
|
||||
j0 = size//2-1
|
||||
|
||||
i = i0
|
||||
j = j0
|
||||
|
||||
idx = [indices[i0, j0]]
|
||||
step_mult = 0
|
||||
for c in range(1, size//2+1):
|
||||
step_mult += 1
|
||||
# steps left
|
||||
for k in range(step_mult):
|
||||
i = i - 1
|
||||
j = j
|
||||
idx.append(indices[i, j])
|
||||
|
||||
# step down
|
||||
for k in range(step_mult):
|
||||
i = i
|
||||
j = j + 1
|
||||
idx.append(indices[i, j])
|
||||
|
||||
step_mult += 1
|
||||
if c < size//2:
|
||||
# step right
|
||||
for k in range(step_mult):
|
||||
i = i + 1
|
||||
j = j
|
||||
idx.append(indices[i, j])
|
||||
|
||||
# step up
|
||||
for k in range(step_mult):
|
||||
i = i
|
||||
j = j - 1
|
||||
idx.append(indices[i, j])
|
||||
else:
|
||||
# end reached
|
||||
for k in range(step_mult-1):
|
||||
i = i + 1
|
||||
idx.append(indices[i, j])
|
||||
|
||||
assert len(idx) == size*size
|
||||
idx = torch.tensor(idx)
|
||||
self.register_buffer('forward_shuffle_idx', idx)
|
||||
self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
|
||||
|
||||
def forward(self, x, reverse=False):
|
||||
if not reverse:
|
||||
return x[:, self.forward_shuffle_idx]
|
||||
else:
|
||||
return x[:, self.backward_shuffle_idx]
|
||||
|
||||
|
||||
class SpiralIn(AbstractPermuter):
|
||||
def __init__(self, H, W):
|
||||
super().__init__()
|
||||
assert H == W
|
||||
size = W
|
||||
indices = np.arange(size*size).reshape(size,size)
|
||||
|
||||
i0 = size//2
|
||||
j0 = size//2-1
|
||||
|
||||
i = i0
|
||||
j = j0
|
||||
|
||||
idx = [indices[i0, j0]]
|
||||
step_mult = 0
|
||||
for c in range(1, size//2+1):
|
||||
step_mult += 1
|
||||
# steps left
|
||||
for k in range(step_mult):
|
||||
i = i - 1
|
||||
j = j
|
||||
idx.append(indices[i, j])
|
||||
|
||||
# step down
|
||||
for k in range(step_mult):
|
||||
i = i
|
||||
j = j + 1
|
||||
idx.append(indices[i, j])
|
||||
|
||||
step_mult += 1
|
||||
if c < size//2:
|
||||
# step right
|
||||
for k in range(step_mult):
|
||||
i = i + 1
|
||||
j = j
|
||||
idx.append(indices[i, j])
|
||||
|
||||
# step up
|
||||
for k in range(step_mult):
|
||||
i = i
|
||||
j = j - 1
|
||||
idx.append(indices[i, j])
|
||||
else:
|
||||
# end reached
|
||||
for k in range(step_mult-1):
|
||||
i = i + 1
|
||||
idx.append(indices[i, j])
|
||||
|
||||
assert len(idx) == size*size
|
||||
idx = idx[::-1]
|
||||
idx = torch.tensor(idx)
|
||||
self.register_buffer('forward_shuffle_idx', idx)
|
||||
self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
|
||||
|
||||
def forward(self, x, reverse=False):
|
||||
if not reverse:
|
||||
return x[:, self.forward_shuffle_idx]
|
||||
else:
|
||||
return x[:, self.backward_shuffle_idx]
|
||||
|
||||
|
||||
class Random(nn.Module):
|
||||
def __init__(self, H, W):
|
||||
super().__init__()
|
||||
indices = np.random.RandomState(1).permutation(H*W)
|
||||
idx = torch.tensor(indices.ravel())
|
||||
self.register_buffer('forward_shuffle_idx', idx)
|
||||
self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
|
||||
|
||||
def forward(self, x, reverse=False):
|
||||
if not reverse:
|
||||
return x[:, self.forward_shuffle_idx]
|
||||
else:
|
||||
return x[:, self.backward_shuffle_idx]
|
||||
|
||||
|
||||
class AlternateParsing(AbstractPermuter):
|
||||
def __init__(self, H, W):
|
||||
super().__init__()
|
||||
indices = np.arange(W*H).reshape(H,W)
|
||||
for i in range(1, H, 2):
|
||||
indices[i, :] = indices[i, ::-1]
|
||||
idx = indices.flatten()
|
||||
assert len(idx) == H*W
|
||||
idx = torch.tensor(idx)
|
||||
self.register_buffer('forward_shuffle_idx', idx)
|
||||
self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
|
||||
|
||||
def forward(self, x, reverse=False):
|
||||
if not reverse:
|
||||
return x[:, self.forward_shuffle_idx]
|
||||
else:
|
||||
return x[:, self.backward_shuffle_idx]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
p0 = AlternateParsing(16, 16)
|
||||
print(p0.forward_shuffle_idx)
|
||||
print(p0.backward_shuffle_idx)
|
||||
|
||||
x = torch.randint(0, 768, size=(11, 256))
|
||||
y = p0(x)
|
||||
xre = p0(y, reverse=True)
|
||||
assert torch.equal(x, xre)
|
||||
|
||||
p1 = SpiralOut(2, 2)
|
||||
print(p1.forward_shuffle_idx)
|
||||
print(p1.backward_shuffle_idx)
|
||||
Reference in New Issue
Block a user