first commit
This commit is contained in:
39
midas/backbones/next_vit.py
Normal file
39
midas/backbones/next_vit.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import timm
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from pathlib import Path
|
||||
from .utils import activations, forward_default, get_activation
|
||||
|
||||
from ..external.next_vit.classification.nextvit import *
|
||||
|
||||
|
||||
def forward_next_vit(pretrained, x):
|
||||
return forward_default(pretrained, x, "forward")
|
||||
|
||||
|
||||
def _make_next_vit_backbone(
|
||||
model,
|
||||
hooks=[2, 6, 36, 39],
|
||||
):
|
||||
pretrained = nn.Module()
|
||||
|
||||
pretrained.model = model
|
||||
pretrained.model.features[hooks[0]].register_forward_hook(get_activation("1"))
|
||||
pretrained.model.features[hooks[1]].register_forward_hook(get_activation("2"))
|
||||
pretrained.model.features[hooks[2]].register_forward_hook(get_activation("3"))
|
||||
pretrained.model.features[hooks[3]].register_forward_hook(get_activation("4"))
|
||||
|
||||
pretrained.activations = activations
|
||||
|
||||
return pretrained
|
||||
|
||||
|
||||
def _make_pretrained_next_vit_large_6m(hooks=None):
|
||||
model = timm.create_model("nextvit_large")
|
||||
|
||||
hooks = [2, 6, 36, 39] if hooks == None else hooks
|
||||
return _make_next_vit_backbone(
|
||||
model,
|
||||
hooks=hooks,
|
||||
)
|
||||
Reference in New Issue
Block a user