# 2022-08-16 12:10:30 +02:00
import os
import time
import numpy as np
from skimage import io
import time
from glob import glob
from tqdm import tqdm
import torch , gc
import torch . nn as nn
from torch . autograd import Variable
import torch . optim as optim
import torch . nn . functional as F
from torchvision . transforms . functional import normalize
from models import *
# Extensions searched for in the dataset folder (each in lower and upper case).
_IMAGE_EXTENSIONS = ("jpg", "jpeg", "png", "bmp", "tiff")


def _list_images(dataset_path):
    """Return the image files located directly inside ``dataset_path``.

    Every extension is searched in lower and upper case. Paths are returned in
    discovery order with duplicates removed, so that a case-insensitive
    filesystem (where ``*.jpg`` and ``*.JPG`` match the same file) does not
    cause an image to be processed twice.
    """
    matches = []
    for ext in _IMAGE_EXTENSIONS:
        for suffix in (ext, ext.upper()):
            matches.extend(glob(os.path.join(dataset_path, "*." + suffix)))
    return list(dict.fromkeys(matches))


def _to_rgb(im):
    """Return ``im`` as an ``(H, W, 3)`` array.

    A 2-D array, or one with fewer than three channels, has its first channel
    replicated three times; any channel beyond the third is dropped. Without
    this, the 3-channel normalisation applied later fails on such inputs.
    """
    if im.ndim < 3:
        im = im[:, :, np.newaxis]
    if im.shape[2] < 3:
        im = np.repeat(im[:, :, :1], 3, axis=2)
    return im[:, :, :3]


def _min_max_normalize(pred):
    """Rescale ``pred`` linearly so its minimum maps to 0 and its maximum to 1.

    A constant tensor has no range to stretch; it yields all zeros instead of
    the NaNs that a division by zero would produce.
    """
    lo = torch.min(pred)
    span = torch.max(pred) - lo
    if span > 0:
        return (pred - lo) / span
    return torch.zeros_like(pred)


def _output_name(im_path):
    """Return the file name of ``im_path`` without directory or last extension."""
    return os.path.splitext(os.path.basename(im_path))[0]


def main():
    """Run the network on every image of the dataset folder and save the masks.

    Each image is resized to ``input_size``, passed through the network, and
    the first output map is resized back to the image's own height and width,
    min-max normalised, and written to ``result_path`` as an 8-bit PNG named
    after the input file.
    """
    dataset_path = "/home/jiayi.zhu/DIS_projects/demo"  # Your dataset path
    model_path = "/home/jiayi.zhu/DIS_projects/saved_models/IS-Net-New/gpu_itr_654000_traLoss_0.1955_traTarLoss_0.0186_valLoss_0.9414_valTarLoss_0.1418_maxF1_0.9117_mae_0.0388_time_0.030502.pth"  # Your model path
    result_path = "/home/jiayi.zhu/DIS_projects/demo_result"  # The folder path that you want to save the results
    input_size = [1024, 1024]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = ISNetDIS()
    net.load_state_dict(torch.load(model_path, map_location=device))
    net.to(device)
    net.eval()

    # imsave does not create missing directories.
    os.makedirs(result_path, exist_ok=True)

    im_list = _list_images(dataset_path)
    # Inference only: skip building the autograd graph for every image.
    with torch.no_grad():
        for im_path in tqdm(im_list):
            print("im_path: ", im_path)
            im = _to_rgb(io.imread(im_path))
            im_shp = im.shape[0:2]

            im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
            # The cast back to uint8 after resizing is kept on purpose so that
            # pixel values are quantised exactly as this pipeline always did.
            im_tensor = F.interpolate(
                torch.unsqueeze(im_tensor, 0), input_size, mode="bilinear"
            ).type(torch.uint8)
            image = torch.divide(im_tensor, 255.0)
            image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])

            result = net(image.to(device))
            # Resize the first output map back to the input image's size.
            result = torch.squeeze(
                F.interpolate(result[0][0], im_shp, mode="bilinear"), 0
            )
            result = _min_max_normalize(result)

            mask = (result * 255).permute(1, 2, 0).cpu().numpy().astype(np.uint8)
            io.imsave(os.path.join(result_path, _output_name(im_path) + ".png"), mask)


if __name__ == "__main__":
    main()