Files
Magic123/midas/base_model.py
Guocheng Qian 13e18567fa first commit
2023-08-02 19:51:43 -07:00

17 lines
367 B
Python

import torch
class BaseModel(torch.nn.Module):
    """``torch.nn.Module`` base class that adds a helper for loading weights from disk."""

    def load(self, path):
        """Load model weights from file into this module (in place).

        Args:
            path (str): file path of a checkpoint written with ``torch.save``.
                The file may contain either a bare ``state_dict`` or a training
                checkpoint dict whose weights are stored under the ``"model"``
                key (e.g. ``{"model": ..., "optimizer": ...}``).

        Raises:
            KeyError: if the checkpoint has an ``"optimizer"`` entry but no
                ``"model"`` entry.
            RuntimeError: propagated from ``load_state_dict`` (strict mode) when
                the stored keys or shapes do not match this module.
        """
        # NOTE(review): torch.load unpickles the file, so only load checkpoints
        # from trusted sources.
        parameters = torch.load(path, map_location=torch.device('cpu'))

        # Training checkpoints wrap the weights under "model". The original code
        # only unwrapped when an "optimizer" entry was present; also unwrap when a
        # nested "model" dict appears on its own. Otherwise load_state_dict would
        # reject the single unexpected "model" key.
        has_nested_model = isinstance(parameters, dict) and isinstance(
            parameters.get("model"), dict
        )
        if "optimizer" in parameters or has_nested_model:
            parameters = parameters["model"]

        self.load_state_dict(parameters)