DIS/IS-Net/Inference.py

53 lines
2.2 KiB
Python
Raw Normal View History

2022-08-16 12:10:30 +02:00
import os
import time
import numpy as np
from skimage import io
import time
from glob import glob
from tqdm import tqdm
import torch, gc
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
from models import *
# File extensions (globbed in both lower and upper case) picked up from the dataset folder.
_IMAGE_EXTENSIONS = ("jpg", "jpeg", "png", "bmp", "tif", "tiff")


def _list_images(folder):
    """Return the sorted, de-duplicated image paths directly inside *folder*.

    Both the lower- and upper-case form of every extension in
    ``_IMAGE_EXTENSIONS`` is globbed. Collecting into a set means a file
    matched by both ``*.jpg`` and ``*.JPG`` (as happens on case-insensitive
    file systems) is only processed once.
    """
    paths = set()
    for ext in _IMAGE_EXTENSIONS:
        for variant in (ext, ext.upper()):
            paths.update(glob(os.path.join(folder, "*." + variant)))
    return sorted(paths)


def _to_rgb(im):
    """Return the image array *im* as an H x W x 3 array.

    Grayscale (H x W or H x W x 1) and grayscale+alpha (H x W x 2) images
    have their first channel replicated three times; the alpha channel of
    RGBA input is dropped. The network input is normalized below with a
    three-element mean/std, so any other channel count would fail there.
    """
    if im.ndim == 2:
        im = im[:, :, np.newaxis]
    if im.shape[2] < 3:
        im = np.repeat(im[:, :, :1], 3, axis=2)
    return im[:, :, :3]


def _to_input_tensor(im, input_size):
    """Resize an H x W x C image array to *input_size* and scale it to [0, 1].

    Returns a float tensor of shape (1, C, input_size[0], input_size[1]).
    The bilinearly resized values are truncated back to uint8 before the
    division, so the tensor holds multiples of 1/255.
    NOTE(review): assumes 8-bit input; 16-bit images would wrap in the
    uint8 cast — TODO confirm whether such inputs need support.
    """
    im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)
    # F.interpolate is the non-deprecated replacement for F.upsample.
    im_tensor = F.interpolate(im_tensor, size=tuple(input_size), mode="bilinear").type(torch.uint8)
    return torch.divide(im_tensor, 255.0)


def _to_mask(pred, im_shp):
    """Resize a prediction map to *im_shp* and min-max scale it to 0-255.

    Assumes *pred* has shape (1, 1, h, w) — a single-channel map; confirm
    against the model's outputs. Returns an H x W uint8 array. A constant
    prediction (max == min) maps to all zeros instead of dividing by zero,
    which would otherwise yield NaNs and an undefined uint8 cast.
    """
    pred = F.interpolate(pred, size=tuple(im_shp), mode="bilinear")[0, 0]
    mi = torch.min(pred)
    span = torch.max(pred) - mi
    pred = pred - mi
    if span > 0:
        pred = pred / span
    return (pred * 255).detach().cpu().numpy().astype(np.uint8)


if __name__ == "__main__":
    dataset_path = "/home/jiayi.zhu/DIS_projects/demo"  # Your dataset path
    model_path = "/home/jiayi.zhu/DIS_projects/saved_models/IS-Net-New/gpu_itr_654000_traLoss_0.1955_traTarLoss_0.0186_valLoss_0.9414_valTarLoss_0.1418_maxF1_0.9117_mae_0.0388_time_0.030502.pth"  # Your model path
    result_path = "/home/jiayi.zhu/DIS_projects/demo_result"  # The folder path that you want to save the results

    input_size = [1024, 1024]  # (height, width) the network is run at

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = ISNetDIS()
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    net.load_state_dict(torch.load(model_path, map_location=device))
    net.to(device)
    net.eval()

    # Create the output folder up front so saving does not fail when it is missing.
    os.makedirs(result_path, exist_ok=True)
    im_list = _list_images(dataset_path)

    # Inference only: no_grad avoids building an autograd graph for every image.
    with torch.no_grad():
        for im_path in tqdm(im_list, total=len(im_list)):
            print("im_path: ", im_path)
            im = _to_rgb(io.imread(im_path))
            im_shp = im.shape[0:2]
            image = _to_input_tensor(im, input_size)
            image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
            result = net(image.to(device))
            # result[0][0]: presumably the first (full-resolution) side output of
            # ISNetDIS — confirm against the model definition.
            mask = _to_mask(result[0][0], im_shp)
            # basename/splitext handle OS-specific separators and keep dots inside
            # the stem ("a.b.jpg" -> "a.b"), so distinct inputs don't overwrite each other.
            im_name = os.path.splitext(os.path.basename(im_path))[0]
            io.imsave(os.path.join(result_path, im_name + ".png"), mask)