official release of our isnet and dis5k
BIN
DIS5K-Dataset-Terms-of-Use.pdf
Normal file
BIN
IS-Net/__pycache__/basics.cpython-37.pyc
Normal file
BIN
IS-Net/__pycache__/data_loader_cache.cpython-37.pyc
Normal file
74
IS-Net/basics.py
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
import os
|
||||||
|
# os.environ['CUDA_VISIBLE_DEVICES'] = '2'
|
||||||
|
from skimage import io, transform
|
||||||
|
import torch
|
||||||
|
import torchvision
|
||||||
|
from torch.autograd import Variable
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
from torchvision import transforms, utils
|
||||||
|
import torch.optim as optim
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
import glob
|
||||||
|
|
||||||
|
def mae_torch(pred,gt):
|
||||||
|
|
||||||
|
h,w = gt.shape[0:2]
|
||||||
|
sumError = torch.sum(torch.absolute(torch.sub(pred.float(), gt.float())))
|
||||||
|
maeError = torch.divide(sumError,float(h)*float(w)*255.0+1e-4)
|
||||||
|
|
||||||
|
return maeError
|
||||||
|
|
||||||
|
def f1score_torch(pd,gt):
|
||||||
|
|
||||||
|
# print(gt.shape)
|
||||||
|
gtNum = torch.sum((gt>128).float()*1) ## number of ground truth pixels
|
||||||
|
|
||||||
|
pp = pd[gt>128]
|
||||||
|
nn = pd[gt<=128]
|
||||||
|
|
||||||
|
pp_hist =torch.histc(pp,bins=255,min=0,max=255)
|
||||||
|
nn_hist = torch.histc(nn,bins=255,min=0,max=255)
|
||||||
|
|
||||||
|
|
||||||
|
pp_hist_flip = torch.flipud(pp_hist)
|
||||||
|
nn_hist_flip = torch.flipud(nn_hist)
|
||||||
|
|
||||||
|
pp_hist_flip_cum = torch.cumsum(pp_hist_flip, dim=0)
|
||||||
|
nn_hist_flip_cum = torch.cumsum(nn_hist_flip, dim=0)
|
||||||
|
|
||||||
|
precision = (pp_hist_flip_cum)/(pp_hist_flip_cum + nn_hist_flip_cum + 1e-4)#torch.divide(pp_hist_flip_cum,torch.sum(torch.sum(pp_hist_flip_cum, nn_hist_flip_cum), 1e-4))
|
||||||
|
recall = (pp_hist_flip_cum)/(gtNum + 1e-4)
|
||||||
|
f1 = (1+0.3)*precision*recall/(0.3*precision+recall + 1e-4)
|
||||||
|
|
||||||
|
return torch.reshape(precision,(1,precision.shape[0])),torch.reshape(recall,(1,recall.shape[0])),torch.reshape(f1,(1,f1.shape[0]))
|
||||||
|
|
||||||
|
|
||||||
|
def f1_mae_torch(pred, gt, valid_dataset, idx, mybins, hypar):
|
||||||
|
|
||||||
|
import time
|
||||||
|
tic = time.time()
|
||||||
|
|
||||||
|
if(len(gt.shape)>2):
|
||||||
|
gt = gt[:,:,0]
|
||||||
|
|
||||||
|
pre, rec, f1 = f1score_torch(pred,gt)
|
||||||
|
mae = mae_torch(pred,gt)
|
||||||
|
|
||||||
|
|
||||||
|
# hypar["valid_out_dir"] = hypar["valid_out_dir"]+"-eval" ###
|
||||||
|
if(hypar["valid_out_dir"]!=""):
|
||||||
|
if(not os.path.exists(hypar["valid_out_dir"])):
|
||||||
|
os.mkdir(hypar["valid_out_dir"])
|
||||||
|
dataset_folder = os.path.join(hypar["valid_out_dir"],valid_dataset.dataset["data_name"][idx])
|
||||||
|
if(not os.path.exists(dataset_folder)):
|
||||||
|
os.mkdir(dataset_folder)
|
||||||
|
io.imsave(os.path.join(dataset_folder,valid_dataset.dataset["im_name"][idx]+".png"),pred.cpu().data.numpy().astype(np.uint8))
|
||||||
|
print(valid_dataset.dataset["im_name"][idx]+".png")
|
||||||
|
print("time for evaluation : ", time.time()-tic)
|
||||||
|
|
||||||
|
return pre.cpu().data.numpy(), rec.cpu().data.numpy(), f1.cpu().data.numpy(), mae.cpu().data.numpy()
|
380
IS-Net/data_loader_cache.py
Normal file
@ -0,0 +1,380 @@
|
|||||||
|
## data loader
|
||||||
|
## Ackownledgement:
|
||||||
|
## We would like to thank Dr. Ibrahim Almakky (https://scholar.google.co.uk/citations?user=T9MTcK0AAAAJ&hl=en)
|
||||||
|
## for his helps in implementing cache machanism of our DIS dataloader.
|
||||||
|
from __future__ import print_function, division
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import random
|
||||||
|
from copy import deepcopy
|
||||||
|
import json
|
||||||
|
from tqdm import tqdm
|
||||||
|
from skimage import io
|
||||||
|
import os
|
||||||
|
from glob import glob
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
from torchvision import transforms, utils
|
||||||
|
from torchvision.transforms.functional import normalize
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
#### --------------------- DIS dataloader cache ---------------------####
|
||||||
|
|
||||||
|
def get_im_gt_name_dict(datasets, flag='valid'):
|
||||||
|
print("------------------------------", flag, "--------------------------------")
|
||||||
|
name_im_gt_list = []
|
||||||
|
for i in range(len(datasets)):
|
||||||
|
print("--->>>", flag, " dataset ",i,"/",len(datasets)," ",datasets[i]["name"],"<<<---")
|
||||||
|
tmp_im_list, tmp_gt_list = [], []
|
||||||
|
tmp_im_list = glob(datasets[i]["im_dir"]+os.sep+'*'+datasets[i]["im_ext"])
|
||||||
|
|
||||||
|
# img_name_dict[im_dirs[i][0]] = tmp_im_list
|
||||||
|
print('-im-',datasets[i]["name"],datasets[i]["im_dir"], ': ',len(tmp_im_list))
|
||||||
|
|
||||||
|
if(datasets[i]["gt_dir"]==""):
|
||||||
|
print('-gt-', datasets[i]["name"], datasets[i]["gt_dir"], ': ', 'No Ground Truth Found')
|
||||||
|
else:
|
||||||
|
tmp_gt_list = [datasets[i]["gt_dir"]+os.sep+x.split(os.sep)[-1].split(datasets[i]["im_ext"])[0]+datasets[i]["gt_ext"] for x in tmp_im_list]
|
||||||
|
# lbl_name_dict[im_dirs[i][0]] = tmp_gt_list
|
||||||
|
print('-gt-', datasets[i]["name"],datasets[i]["gt_dir"], ': ',len(tmp_gt_list))
|
||||||
|
|
||||||
|
|
||||||
|
if flag=="train": ## combine multiple training sets into one dataset
|
||||||
|
if len(name_im_gt_list)==0:
|
||||||
|
name_im_gt_list.append({"dataset_name":datasets[i]["name"],
|
||||||
|
"im_path":tmp_im_list,
|
||||||
|
"gt_path":tmp_gt_list,
|
||||||
|
"im_ext":datasets[i]["im_ext"],
|
||||||
|
"gt_ext":datasets[i]["gt_ext"],
|
||||||
|
"cache_dir":datasets[i]["cache_dir"]})
|
||||||
|
else:
|
||||||
|
name_im_gt_list[0]["dataset_name"] = name_im_gt_list[0]["dataset_name"] + "_" + datasets[i]["name"]
|
||||||
|
name_im_gt_list[0]["im_path"] = name_im_gt_list[0]["im_path"] + tmp_im_list
|
||||||
|
name_im_gt_list[0]["gt_path"] = name_im_gt_list[0]["gt_path"] + tmp_gt_list
|
||||||
|
if datasets[i]["im_ext"]!=".jpg" or datasets[i]["gt_ext"]!=".png":
|
||||||
|
print("Error: Please make sure all you images and ground truth masks are in jpg and png format respectively !!!")
|
||||||
|
exit()
|
||||||
|
name_im_gt_list[0]["im_ext"] = ".jpg"
|
||||||
|
name_im_gt_list[0]["gt_ext"] = ".png"
|
||||||
|
name_im_gt_list[0]["cache_dir"] = os.sep.join(datasets[i]["cache_dir"].split(os.sep)[0:-1])+os.sep+name_im_gt_list[0]["dataset_name"]
|
||||||
|
else: ## keep different validation or inference datasets as separate ones
|
||||||
|
name_im_gt_list.append({"dataset_name":datasets[i]["name"],
|
||||||
|
"im_path":tmp_im_list,
|
||||||
|
"gt_path":tmp_gt_list,
|
||||||
|
"im_ext":datasets[i]["im_ext"],
|
||||||
|
"gt_ext":datasets[i]["gt_ext"],
|
||||||
|
"cache_dir":datasets[i]["cache_dir"]})
|
||||||
|
|
||||||
|
return name_im_gt_list
|
||||||
|
|
||||||
|
def create_dataloaders(name_im_gt_list, cache_size=[], cache_boost=True, my_transforms=[], batch_size=1, shuffle=False):
|
||||||
|
## model="train": return one dataloader for training
|
||||||
|
## model="valid": return a list of dataloaders for validation or testing
|
||||||
|
|
||||||
|
gos_dataloaders = []
|
||||||
|
gos_datasets = []
|
||||||
|
# if(mode=="train"):
|
||||||
|
if(len(name_im_gt_list)==0):
|
||||||
|
return gos_dataloaders, gos_datasets
|
||||||
|
|
||||||
|
num_workers_ = 1
|
||||||
|
if(batch_size>1):
|
||||||
|
num_workers_ = 2
|
||||||
|
if(batch_size>4):
|
||||||
|
num_workers_ = 4
|
||||||
|
if(batch_size>8):
|
||||||
|
num_workers_ = 8
|
||||||
|
|
||||||
|
for i in range(0,len(name_im_gt_list)):
|
||||||
|
gos_dataset = GOSDatasetCache([name_im_gt_list[i]],
|
||||||
|
cache_size = cache_size,
|
||||||
|
cache_path = name_im_gt_list[i]["cache_dir"],
|
||||||
|
cache_boost = cache_boost,
|
||||||
|
transform = transforms.Compose(my_transforms))
|
||||||
|
gos_dataloaders.append(DataLoader(gos_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers_))
|
||||||
|
gos_datasets.append(gos_dataset)
|
||||||
|
|
||||||
|
return gos_dataloaders, gos_datasets
|
||||||
|
|
||||||
|
def im_reader(im_path):
|
||||||
|
return io.imread(im_path)
|
||||||
|
|
||||||
|
def im_preprocess(im,size):
|
||||||
|
if len(im.shape) < 3:
|
||||||
|
im = im[:, :, np.newaxis]
|
||||||
|
if im.shape[2] == 1:
|
||||||
|
im = np.repeat(im, 3, axis=2)
|
||||||
|
im_tensor = torch.tensor(im, dtype=torch.float32)
|
||||||
|
im_tensor = torch.transpose(torch.transpose(im_tensor,1,2),0,1)
|
||||||
|
if(len(size)<2):
|
||||||
|
return im_tensor, im.shape[0:2]
|
||||||
|
else:
|
||||||
|
im_tensor = torch.unsqueeze(im_tensor,0)
|
||||||
|
im_tensor = F.upsample(im_tensor, size, mode="bilinear")
|
||||||
|
im_tensor = torch.squeeze(im_tensor,0)
|
||||||
|
|
||||||
|
return im_tensor.type(torch.uint8), im.shape[0:2]
|
||||||
|
|
||||||
|
def gt_preprocess(gt,size):
|
||||||
|
if len(gt.shape) > 2:
|
||||||
|
gt = gt[:, :, 0]
|
||||||
|
|
||||||
|
gt_tensor = torch.unsqueeze(torch.tensor(gt, dtype=torch.uint8),0)
|
||||||
|
|
||||||
|
if(len(size)<2):
|
||||||
|
return gt_tensor.type(torch.uint8), gt.shape[0:2]
|
||||||
|
else:
|
||||||
|
gt_tensor = torch.unsqueeze(torch.tensor(gt_tensor, dtype=torch.float32),0)
|
||||||
|
gt_tensor = F.upsample(gt_tensor, size, mode="bilinear")
|
||||||
|
gt_tensor = torch.squeeze(gt_tensor,0)
|
||||||
|
|
||||||
|
return gt_tensor.type(torch.uint8), gt.shape[0:2]
|
||||||
|
# return gt_tensor, gt.shape[0:2]
|
||||||
|
|
||||||
|
class GOSRandomHFlip(object):
|
||||||
|
def __init__(self,prob=0.5):
|
||||||
|
self.prob = prob
|
||||||
|
def __call__(self,sample):
|
||||||
|
imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape']
|
||||||
|
|
||||||
|
# random horizontal flip
|
||||||
|
if random.random() >= self.prob:
|
||||||
|
image = torch.flip(image,dims=[2])
|
||||||
|
label = torch.flip(label,dims=[2])
|
||||||
|
|
||||||
|
return {'imidx':imidx,'image':image, 'label':label, 'shape':shape}
|
||||||
|
|
||||||
|
class GOSResize(object):
|
||||||
|
def __init__(self,size=[320,320]):
|
||||||
|
self.size = size
|
||||||
|
def __call__(self,sample):
|
||||||
|
imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape']
|
||||||
|
|
||||||
|
# import time
|
||||||
|
# start = time.time()
|
||||||
|
|
||||||
|
image = torch.squeeze(F.upsample(torch.unsqueeze(image,0),self.size,mode='bilinear'),dim=0)
|
||||||
|
label = torch.squeeze(F.upsample(torch.unsqueeze(label,0),self.size,mode='bilinear'),dim=0)
|
||||||
|
|
||||||
|
# print("time for resize: ", time.time()-start)
|
||||||
|
|
||||||
|
return {'imidx':imidx,'image':image, 'label':label, 'shape':shape}
|
||||||
|
|
||||||
|
class GOSRandomCrop(object):
|
||||||
|
def __init__(self,size=[288,288]):
|
||||||
|
self.size = size
|
||||||
|
def __call__(self,sample):
|
||||||
|
imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape']
|
||||||
|
|
||||||
|
h, w = image.shape[1:]
|
||||||
|
new_h, new_w = self.size
|
||||||
|
|
||||||
|
top = np.random.randint(0, h - new_h)
|
||||||
|
left = np.random.randint(0, w - new_w)
|
||||||
|
|
||||||
|
image = image[:,top:top+new_h,left:left+new_w]
|
||||||
|
label = label[:,top:top+new_h,left:left+new_w]
|
||||||
|
|
||||||
|
return {'imidx':imidx,'image':image, 'label':label, 'shape':shape}
|
||||||
|
|
||||||
|
|
||||||
|
class GOSNormalize(object):
|
||||||
|
def __init__(self, mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]):
|
||||||
|
self.mean = mean
|
||||||
|
self.std = std
|
||||||
|
|
||||||
|
def __call__(self,sample):
|
||||||
|
|
||||||
|
imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape']
|
||||||
|
image = normalize(image,self.mean,self.std)
|
||||||
|
|
||||||
|
return {'imidx':imidx,'image':image, 'label':label, 'shape':shape}
|
||||||
|
|
||||||
|
|
||||||
|
class GOSDatasetCache(Dataset):
|
||||||
|
|
||||||
|
def __init__(self, name_im_gt_list, cache_size=[], cache_path='./cache', cache_file_name='dataset.json', cache_boost=False, transform=None):
|
||||||
|
|
||||||
|
|
||||||
|
self.cache_size = cache_size
|
||||||
|
self.cache_path = cache_path
|
||||||
|
self.cache_file_name = cache_file_name
|
||||||
|
self.cache_boost_name = ""
|
||||||
|
|
||||||
|
self.cache_boost = cache_boost
|
||||||
|
# self.ims_npy = None
|
||||||
|
# self.gts_npy = None
|
||||||
|
|
||||||
|
## cache all the images and ground truth into a single pytorch tensor
|
||||||
|
self.ims_pt = None
|
||||||
|
self.gts_pt = None
|
||||||
|
|
||||||
|
## we will cache the npy as well regardless of the cache_boost
|
||||||
|
# if(self.cache_boost):
|
||||||
|
self.cache_boost_name = cache_file_name.split('.json')[0]
|
||||||
|
|
||||||
|
self.transform = transform
|
||||||
|
|
||||||
|
self.dataset = {}
|
||||||
|
|
||||||
|
## combine different datasets into one
|
||||||
|
dataset_names = []
|
||||||
|
dt_name_list = [] # dataset name per image
|
||||||
|
im_name_list = [] # image name
|
||||||
|
im_path_list = [] # im path
|
||||||
|
gt_path_list = [] # gt path
|
||||||
|
im_ext_list = [] # im ext
|
||||||
|
gt_ext_list = [] # gt ext
|
||||||
|
for i in range(0,len(name_im_gt_list)):
|
||||||
|
dataset_names.append(name_im_gt_list[i]["dataset_name"])
|
||||||
|
# dataset name repeated based on the number of images in this dataset
|
||||||
|
dt_name_list.extend([name_im_gt_list[i]["dataset_name"] for x in name_im_gt_list[i]["im_path"]])
|
||||||
|
im_name_list.extend([x.split(os.sep)[-1].split(name_im_gt_list[i]["im_ext"])[0] for x in name_im_gt_list[i]["im_path"]])
|
||||||
|
im_path_list.extend(name_im_gt_list[i]["im_path"])
|
||||||
|
gt_path_list.extend(name_im_gt_list[i]["gt_path"])
|
||||||
|
im_ext_list.extend([name_im_gt_list[i]["im_ext"] for x in name_im_gt_list[i]["im_path"]])
|
||||||
|
gt_ext_list.extend([name_im_gt_list[i]["gt_ext"] for x in name_im_gt_list[i]["gt_path"]])
|
||||||
|
|
||||||
|
|
||||||
|
self.dataset["data_name"] = dt_name_list
|
||||||
|
self.dataset["im_name"] = im_name_list
|
||||||
|
self.dataset["im_path"] = im_path_list
|
||||||
|
self.dataset["ori_im_path"] = deepcopy(im_path_list)
|
||||||
|
self.dataset["gt_path"] = gt_path_list
|
||||||
|
self.dataset["ori_gt_path"] = deepcopy(gt_path_list)
|
||||||
|
self.dataset["im_shp"] = []
|
||||||
|
self.dataset["gt_shp"] = []
|
||||||
|
self.dataset["im_ext"] = im_ext_list
|
||||||
|
self.dataset["gt_ext"] = gt_ext_list
|
||||||
|
|
||||||
|
|
||||||
|
self.dataset["ims_pt_dir"] = ""
|
||||||
|
self.dataset["gts_pt_dir"] = ""
|
||||||
|
|
||||||
|
self.dataset = self.manage_cache(dataset_names)
|
||||||
|
|
||||||
|
def manage_cache(self,dataset_names):
|
||||||
|
if not os.path.exists(self.cache_path): # create the folder for cache
|
||||||
|
os.mkdir(self.cache_path)
|
||||||
|
cache_folder = os.path.join(self.cache_path, "_".join(dataset_names)+"_"+"x".join([str(x) for x in self.cache_size]))
|
||||||
|
if not os.path.exists(cache_folder): # check if the cache files are there, if not then cache
|
||||||
|
return self.cache(cache_folder)
|
||||||
|
return self.load_cache(cache_folder)
|
||||||
|
|
||||||
|
def cache(self,cache_folder):
|
||||||
|
os.mkdir(cache_folder)
|
||||||
|
cached_dataset = deepcopy(self.dataset)
|
||||||
|
|
||||||
|
# ims_list = []
|
||||||
|
# gts_list = []
|
||||||
|
ims_pt_list = []
|
||||||
|
gts_pt_list = []
|
||||||
|
for i, im_path in tqdm(enumerate(self.dataset["im_path"]), total=len(self.dataset["im_path"])):
|
||||||
|
|
||||||
|
|
||||||
|
im_id = cached_dataset["im_name"][i]
|
||||||
|
|
||||||
|
im = im_reader(im_path)
|
||||||
|
im, im_shp = im_preprocess(im,self.cache_size)
|
||||||
|
im_cache_file = os.path.join(cache_folder,self.dataset["data_name"][i]+"_"+im_id + "_im.pt")
|
||||||
|
torch.save(im,im_cache_file)
|
||||||
|
|
||||||
|
cached_dataset["im_path"][i] = im_cache_file
|
||||||
|
if(self.cache_boost):
|
||||||
|
ims_pt_list.append(torch.unsqueeze(im,0))
|
||||||
|
# ims_list.append(im.cpu().data.numpy().astype(np.uint8))
|
||||||
|
|
||||||
|
|
||||||
|
gt = im_reader(self.dataset["gt_path"][i])
|
||||||
|
gt, gt_shp = gt_preprocess(gt,self.cache_size)
|
||||||
|
gt_cache_file = os.path.join(cache_folder,self.dataset["data_name"][i]+"_"+im_id + "_gt.pt")
|
||||||
|
torch.save(gt,gt_cache_file)
|
||||||
|
cached_dataset["gt_path"][i] = gt_cache_file
|
||||||
|
if(self.cache_boost):
|
||||||
|
gts_pt_list.append(torch.unsqueeze(gt,0))
|
||||||
|
# gts_list.append(gt.cpu().data.numpy().astype(np.uint8))
|
||||||
|
|
||||||
|
# im_shp_cache_file = os.path.join(cache_folder,im_id + "_im_shp.pt")
|
||||||
|
# torch.save(gt_shp, shp_cache_file)
|
||||||
|
cached_dataset["im_shp"].append(im_shp)
|
||||||
|
# self.dataset["im_shp"].append(im_shp)
|
||||||
|
|
||||||
|
# shp_cache_file = os.path.join(cache_folder,im_id + "_gt_shp.pt")
|
||||||
|
# torch.save(gt_shp, shp_cache_file)
|
||||||
|
cached_dataset["gt_shp"].append(gt_shp)
|
||||||
|
# self.dataset["gt_shp"].append(gt_shp)
|
||||||
|
|
||||||
|
if(self.cache_boost):
|
||||||
|
cached_dataset["ims_pt_dir"] = os.path.join(cache_folder, self.cache_boost_name+'_ims.pt')
|
||||||
|
cached_dataset["gts_pt_dir"] = os.path.join(cache_folder, self.cache_boost_name+'_gts.pt')
|
||||||
|
self.ims_pt = torch.cat(ims_pt_list,dim=0)
|
||||||
|
self.gts_pt = torch.cat(gts_pt_list,dim=0)
|
||||||
|
torch.save(torch.cat(ims_pt_list,dim=0),cached_dataset["ims_pt_dir"])
|
||||||
|
torch.save(torch.cat(gts_pt_list,dim=0),cached_dataset["gts_pt_dir"])
|
||||||
|
|
||||||
|
try:
|
||||||
|
json_file = open(os.path.join(cache_folder, self.cache_file_name),"w")
|
||||||
|
json.dump(cached_dataset, json_file)
|
||||||
|
json_file.close()
|
||||||
|
except Exception:
|
||||||
|
raise FileNotFoundError("Cannot create JSON")
|
||||||
|
return cached_dataset
|
||||||
|
|
||||||
|
def load_cache(self, cache_folder):
|
||||||
|
json_file = open(os.path.join(cache_folder,self.cache_file_name),"r")
|
||||||
|
dataset = json.load(json_file)
|
||||||
|
json_file.close()
|
||||||
|
## if cache_boost is true, we will load the image npy and ground truth npy into the RAM
|
||||||
|
## otherwise the pytorch tensor will be loaded
|
||||||
|
if(self.cache_boost):
|
||||||
|
# self.ims_npy = np.load(dataset["ims_npy_dir"])
|
||||||
|
# self.gts_npy = np.load(dataset["gts_npy_dir"])
|
||||||
|
self.ims_pt = torch.load(dataset["ims_pt_dir"], map_location='cpu')
|
||||||
|
self.gts_pt = torch.load(dataset["gts_pt_dir"], map_location='cpu')
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.dataset["im_path"])
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
|
||||||
|
im = None
|
||||||
|
gt = None
|
||||||
|
if(self.cache_boost and self.ims_pt is not None):
|
||||||
|
|
||||||
|
# start = time.time()
|
||||||
|
im = self.ims_pt[idx]#.type(torch.float32)
|
||||||
|
gt = self.gts_pt[idx]#.type(torch.float32)
|
||||||
|
# print(idx, 'time for pt loading: ', time.time()-start)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# import time
|
||||||
|
# start = time.time()
|
||||||
|
# print("tensor***")
|
||||||
|
im_pt_path = os.path.join(self.cache_path,os.sep.join(self.dataset["im_path"][idx].split(os.sep)[-2:]))
|
||||||
|
im = torch.load(im_pt_path)#(self.dataset["im_path"][idx])
|
||||||
|
gt_pt_path = os.path.join(self.cache_path,os.sep.join(self.dataset["gt_path"][idx].split(os.sep)[-2:]))
|
||||||
|
gt = torch.load(gt_pt_path)#(self.dataset["gt_path"][idx])
|
||||||
|
# print(idx,'time for tensor loading: ', time.time()-start)
|
||||||
|
|
||||||
|
|
||||||
|
im_shp = self.dataset["im_shp"][idx]
|
||||||
|
# print("time for loading im and gt: ", time.time()-start)
|
||||||
|
|
||||||
|
# start_time = time.time()
|
||||||
|
im = torch.divide(im,255.0)
|
||||||
|
gt = torch.divide(gt,255.0)
|
||||||
|
# print(idx, 'time for normalize torch divide: ', time.time()-start_time)
|
||||||
|
|
||||||
|
sample = {
|
||||||
|
"imidx": torch.from_numpy(np.array(idx)),
|
||||||
|
"image": im,
|
||||||
|
"label": gt,
|
||||||
|
"shape": torch.from_numpy(np.array(im_shp)),
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.transform:
|
||||||
|
sample = self.transform(sample)
|
||||||
|
|
||||||
|
return sample
|
188
IS-Net/hce_metric_main.py
Normal file
@ -0,0 +1,188 @@
|
|||||||
|
## hce_metric.py
|
||||||
|
import numpy as np
|
||||||
|
from skimage import io
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import cv2 as cv
|
||||||
|
from skimage.morphology import skeletonize
|
||||||
|
from skimage.morphology import erosion, dilation, disk
|
||||||
|
from skimage.measure import label
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from tqdm import tqdm
|
||||||
|
from glob import glob
|
||||||
|
import pickle as pkl
|
||||||
|
|
||||||
|
def filter_bdy_cond(bdy_, mask, cond):
|
||||||
|
|
||||||
|
cond = cv.dilate(cond.astype(np.uint8),disk(1))
|
||||||
|
labels = label(mask) # find the connected regions
|
||||||
|
lbls = np.unique(labels) # the indices of the connected regions
|
||||||
|
indep = np.ones(lbls.shape[0]) # the label of each connected regions
|
||||||
|
indep[0] = 0 # 0 indicate the background region
|
||||||
|
|
||||||
|
boundaries = []
|
||||||
|
h,w = cond.shape[0:2]
|
||||||
|
ind_map = np.zeros((h,w))
|
||||||
|
indep_cnt = 0
|
||||||
|
|
||||||
|
for i in range(0,len(bdy_)):
|
||||||
|
tmp_bdies = []
|
||||||
|
tmp_bdy = []
|
||||||
|
for j in range(0,bdy_[i].shape[0]):
|
||||||
|
r, c = bdy_[i][j,0,1],bdy_[i][j,0,0]
|
||||||
|
|
||||||
|
if(np.sum(cond[r,c])==0 or ind_map[r,c]!=0):
|
||||||
|
if(len(tmp_bdy)>0):
|
||||||
|
tmp_bdies.append(tmp_bdy)
|
||||||
|
tmp_bdy = []
|
||||||
|
continue
|
||||||
|
tmp_bdy.append([c,r])
|
||||||
|
ind_map[r,c] = ind_map[r,c] + 1
|
||||||
|
indep[labels[r,c]] = 0 # indicates part of the boundary of this region needs human correction
|
||||||
|
if(len(tmp_bdy)>0):
|
||||||
|
tmp_bdies.append(tmp_bdy)
|
||||||
|
|
||||||
|
# check if the first and the last boundaries are connected
|
||||||
|
# if yes, invert the first boundary and attach it after the last boundary
|
||||||
|
if(len(tmp_bdies)>1):
|
||||||
|
first_x, first_y = tmp_bdies[0][0]
|
||||||
|
last_x, last_y = tmp_bdies[-1][-1]
|
||||||
|
if((abs(first_x-last_x)==1 and first_y==last_y) or
|
||||||
|
(first_x==last_x and abs(first_y-last_y)==1) or
|
||||||
|
(abs(first_x-last_x)==1 and abs(first_y-last_y)==1)
|
||||||
|
):
|
||||||
|
tmp_bdies[-1].extend(tmp_bdies[0][::-1])
|
||||||
|
del tmp_bdies[0]
|
||||||
|
|
||||||
|
for k in range(0,len(tmp_bdies)):
|
||||||
|
tmp_bdies[k] = np.array(tmp_bdies[k])[:,np.newaxis,:]
|
||||||
|
if(len(tmp_bdies)>0):
|
||||||
|
boundaries.extend(tmp_bdies)
|
||||||
|
|
||||||
|
return boundaries, np.sum(indep)
|
||||||
|
|
||||||
|
# this function approximate each boundary by DP algorithm
|
||||||
|
# https://en.wikipedia.org/wiki/Ramer%E2%80%93Douglas%E2%80%93Peucker_algorithm
|
||||||
|
def approximate_RDP(boundaries,epsilon=1.0):
|
||||||
|
|
||||||
|
boundaries_ = []
|
||||||
|
boundaries_len_ = []
|
||||||
|
pixel_cnt_ = 0
|
||||||
|
|
||||||
|
# polygon approximate of each boundary
|
||||||
|
for i in range(0,len(boundaries)):
|
||||||
|
boundaries_.append(cv.approxPolyDP(boundaries[i],epsilon,False))
|
||||||
|
|
||||||
|
# count the control points number of each boundary and the total control points number of all the boundaries
|
||||||
|
for i in range(0,len(boundaries_)):
|
||||||
|
boundaries_len_.append(len(boundaries_[i]))
|
||||||
|
pixel_cnt_ = pixel_cnt_ + len(boundaries_[i])
|
||||||
|
|
||||||
|
return boundaries_, boundaries_len_, pixel_cnt_
|
||||||
|
|
||||||
|
|
||||||
|
def relax_HCE(gt, rs, gt_ske, relax=5, epsilon=2.0):
|
||||||
|
# print("max(gt_ske): ", np.amax(gt_ske))
|
||||||
|
# gt_ske = gt_ske>128
|
||||||
|
# print("max(gt_ske): ", np.amax(gt_ske))
|
||||||
|
|
||||||
|
# Binarize gt
|
||||||
|
if(len(gt.shape)>2):
|
||||||
|
gt = gt[:,:,0]
|
||||||
|
|
||||||
|
epsilon_gt = 128#(np.amin(gt)+np.amax(gt))/2.0
|
||||||
|
gt = (gt>epsilon_gt).astype(np.uint8)
|
||||||
|
|
||||||
|
# Binarize rs
|
||||||
|
if(len(rs.shape)>2):
|
||||||
|
rs = rs[:,:,0]
|
||||||
|
epsilon_rs = 128#(np.amin(rs)+np.amax(rs))/2.0
|
||||||
|
rs = (rs>epsilon_rs).astype(np.uint8)
|
||||||
|
|
||||||
|
Union = np.logical_or(gt,rs)
|
||||||
|
TP = np.logical_and(gt,rs)
|
||||||
|
FP = rs - TP
|
||||||
|
FN = gt - TP
|
||||||
|
|
||||||
|
# relax the Union of gt and rs
|
||||||
|
Union_erode = Union.copy()
|
||||||
|
Union_erode = cv.erode(Union_erode.astype(np.uint8),disk(1),iterations=relax)
|
||||||
|
|
||||||
|
# --- get the relaxed False Positive regions for computing the human efforts in correcting them ---
|
||||||
|
FP_ = np.logical_and(FP,Union_erode) # get the relaxed FP
|
||||||
|
for i in range(0,relax):
|
||||||
|
FP_ = cv.dilate(FP_.astype(np.uint8),disk(1))
|
||||||
|
FP_ = np.logical_and(FP_, 1-np.logical_or(TP,FN))
|
||||||
|
FP_ = np.logical_and(FP, FP_)
|
||||||
|
|
||||||
|
# --- get the relaxed False Negative regions for computing the human efforts in correcting them ---
|
||||||
|
FN_ = np.logical_and(FN,Union_erode) # preserve the structural components of FN
|
||||||
|
## recover the FN, where pixels are not close to the TP borders
|
||||||
|
for i in range(0,relax):
|
||||||
|
FN_ = cv.dilate(FN_.astype(np.uint8),disk(1))
|
||||||
|
FN_ = np.logical_and(FN_,1-np.logical_or(TP,FP))
|
||||||
|
FN_ = np.logical_and(FN,FN_)
|
||||||
|
FN_ = np.logical_or(FN_, np.logical_xor(gt_ske,np.logical_and(TP,gt_ske))) # preserve the structural components of FN
|
||||||
|
|
||||||
|
## 2. =============Find exact polygon control points and independent regions==============
|
||||||
|
## find contours from FP_
|
||||||
|
ctrs_FP, hier_FP = cv.findContours(FP_.astype(np.uint8), cv.RETR_TREE, cv.CHAIN_APPROX_NONE)
|
||||||
|
## find control points and independent regions for human correction
|
||||||
|
bdies_FP, indep_cnt_FP = filter_bdy_cond(ctrs_FP, FP_, np.logical_or(TP,FN_))
|
||||||
|
## find contours from FN_
|
||||||
|
ctrs_FN, hier_FN = cv.findContours(FN_.astype(np.uint8), cv.RETR_TREE, cv.CHAIN_APPROX_NONE)
|
||||||
|
## find control points and independent regions for human correction
|
||||||
|
bdies_FN, indep_cnt_FN = filter_bdy_cond(ctrs_FN, FN_, 1-np.logical_or(np.logical_or(TP,FP_),FN_))
|
||||||
|
|
||||||
|
poly_FP, poly_FP_len, poly_FP_point_cnt = approximate_RDP(bdies_FP,epsilon=epsilon)
|
||||||
|
poly_FN, poly_FN_len, poly_FN_point_cnt = approximate_RDP(bdies_FN,epsilon=epsilon)
|
||||||
|
|
||||||
|
return poly_FP_point_cnt, indep_cnt_FP, poly_FN_point_cnt, indep_cnt_FN
|
||||||
|
|
||||||
|
def compute_hce(pred_root,gt_root,gt_ske_root):
|
||||||
|
|
||||||
|
gt_name_list = glob(pred_root+'/*.png')
|
||||||
|
gt_name_list = sorted([x.split('/')[-1] for x in gt_name_list])
|
||||||
|
|
||||||
|
hces = []
|
||||||
|
for gt_name in tqdm(gt_name_list, total=len(gt_name_list)):
|
||||||
|
gt_path = os.path.join(gt_root, gt_name)
|
||||||
|
pred_path = os.path.join(pred_root, gt_name)
|
||||||
|
|
||||||
|
gt = cv.imread(gt_path, cv.IMREAD_GRAYSCALE)
|
||||||
|
pred = cv.imread(pred_path, cv.IMREAD_GRAYSCALE)
|
||||||
|
|
||||||
|
ske_path = os.path.join(gt_ske_root,gt_name)
|
||||||
|
if os.path.exists(ske_path):
|
||||||
|
ske = cv.imread(ske_path,cv.IMREAD_GRAYSCALE)
|
||||||
|
ske = ske>128
|
||||||
|
else:
|
||||||
|
ske = skeletonize(gt>128)
|
||||||
|
|
||||||
|
FP_points, FP_indep, FN_points, FN_indep = relax_HCE(gt, pred,ske)
|
||||||
|
print(gt_path.split('/')[-1],FP_points, FP_indep, FN_points, FN_indep)
|
||||||
|
hces.append([FP_points, FP_indep, FN_points, FN_indep, FP_points+FP_indep+FN_points+FN_indep])
|
||||||
|
|
||||||
|
hce_metric ={'names': gt_name_list,
|
||||||
|
'hces': hces}
|
||||||
|
|
||||||
|
|
||||||
|
file_metric = open(pred_root+'/hce_metric.pkl','wb')
|
||||||
|
pkl.dump(hce_metric,file_metric)
|
||||||
|
# file_metrics.write(cmn_metrics)
|
||||||
|
file_metric.close()
|
||||||
|
|
||||||
|
return np.mean(np.array(hces)[:,-1])
|
||||||
|
|
||||||
|
def main():
|
||||||
|
|
||||||
|
gt_root = "../DIS5K/DIS-VD/gt"
|
||||||
|
gt_ske_root = ""
|
||||||
|
pred_root = "../Results/isnet(ours)/DIS-VD"
|
||||||
|
|
||||||
|
print("The average HCE metric: ", compute_hce(pred_root,gt_root,gt_ske_root))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
1
IS-Net/models/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
from models.isnet import ISNetGTEncoder, ISNetDIS
|
BIN
IS-Net/models/__pycache__/__init__.cpython-37.pyc
Normal file
BIN
IS-Net/models/__pycache__/u2netfast.cpython-37.pyc
Normal file
614
IS-Net/models/isnet.py
Normal file
@ -0,0 +1,614 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torchvision import models
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
bce_loss = nn.BCELoss(size_average=True)
|
||||||
|
def muti_loss_fusion(preds, target):
|
||||||
|
loss0 = 0.0
|
||||||
|
loss = 0.0
|
||||||
|
|
||||||
|
for i in range(0,len(preds)):
|
||||||
|
# print("i: ", i, preds[i].shape)
|
||||||
|
if(preds[i].shape[2]!=target.shape[2] or preds[i].shape[3]!=target.shape[3]):
|
||||||
|
# tmp_target = _upsample_like(target,preds[i])
|
||||||
|
tmp_target = F.interpolate(target, size=preds[i].size()[2:], mode='bilinear', align_corners=True)
|
||||||
|
loss = loss + bce_loss(preds[i],tmp_target)
|
||||||
|
else:
|
||||||
|
loss = loss + bce_loss(preds[i],target)
|
||||||
|
if(i==0):
|
||||||
|
loss0 = loss
|
||||||
|
return loss0, loss
|
||||||
|
|
||||||
|
fea_loss = nn.MSELoss(size_average=True)
|
||||||
|
kl_loss = nn.KLDivLoss(size_average=True)
|
||||||
|
l1_loss = nn.L1Loss(size_average=True)
|
||||||
|
smooth_l1_loss = nn.SmoothL1Loss(size_average=True)
|
||||||
|
def muti_loss_fusion_kl(preds, target, dfs, fs, mode='MSE'):
|
||||||
|
loss0 = 0.0
|
||||||
|
loss = 0.0
|
||||||
|
|
||||||
|
for i in range(0,len(preds)):
|
||||||
|
# print("i: ", i, preds[i].shape)
|
||||||
|
if(preds[i].shape[2]!=target.shape[2] or preds[i].shape[3]!=target.shape[3]):
|
||||||
|
# tmp_target = _upsample_like(target,preds[i])
|
||||||
|
tmp_target = F.interpolate(target, size=preds[i].size()[2:], mode='bilinear', align_corners=True)
|
||||||
|
loss = loss + bce_loss(preds[i],tmp_target)
|
||||||
|
else:
|
||||||
|
loss = loss + bce_loss(preds[i],target)
|
||||||
|
if(i==0):
|
||||||
|
loss0 = loss
|
||||||
|
|
||||||
|
for i in range(0,len(dfs)):
|
||||||
|
if(mode=='MSE'):
|
||||||
|
loss = loss + fea_loss(dfs[i],fs[i]) ### add the mse loss of features as additional constraints
|
||||||
|
# print("fea_loss: ", fea_loss(dfs[i],fs[i]).item())
|
||||||
|
elif(mode=='KL'):
|
||||||
|
loss = loss + kl_loss(F.log_softmax(dfs[i],dim=1),F.softmax(fs[i],dim=1))
|
||||||
|
# print("kl_loss: ", kl_loss(F.log_softmax(dfs[i],dim=1),F.softmax(fs[i],dim=1)).item())
|
||||||
|
elif(mode=='MAE'):
|
||||||
|
loss = loss + l1_loss(dfs[i],fs[i])
|
||||||
|
# print("ls_loss: ", l1_loss(dfs[i],fs[i]))
|
||||||
|
elif(mode=='SmoothL1'):
|
||||||
|
loss = loss + smooth_l1_loss(dfs[i],fs[i])
|
||||||
|
# print("SmoothL1: ", smooth_l1_loss(dfs[i],fs[i]).item())
|
||||||
|
|
||||||
|
return loss0, loss
|
||||||
|
|
||||||
|
class REBNCONV(nn.Module):
|
||||||
|
def __init__(self,in_ch=3,out_ch=3,dirate=1,stride=1):
|
||||||
|
super(REBNCONV,self).__init__()
|
||||||
|
|
||||||
|
self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate,stride=stride)
|
||||||
|
self.bn_s1 = nn.BatchNorm2d(out_ch)
|
||||||
|
self.relu_s1 = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
def forward(self,x):
|
||||||
|
|
||||||
|
hx = x
|
||||||
|
xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
|
||||||
|
|
||||||
|
return xout
|
||||||
|
|
||||||
|
## upsample tensor 'src' to have the same spatial size with tensor 'tar'
|
||||||
|
def _upsample_like(src,tar):
|
||||||
|
|
||||||
|
src = F.upsample(src,size=tar.shape[2:],mode='bilinear')
|
||||||
|
|
||||||
|
return src
|
||||||
|
|
||||||
|
|
||||||
|
### RSU-7 ###
|
||||||
|
class RSU7(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
|
||||||
|
super(RSU7,self).__init__()
|
||||||
|
|
||||||
|
self.in_ch = in_ch
|
||||||
|
self.mid_ch = mid_ch
|
||||||
|
self.out_ch = out_ch
|
||||||
|
|
||||||
|
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) ## 1 -> 1/2
|
||||||
|
|
||||||
|
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
|
||||||
|
self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||||
|
|
||||||
|
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
||||||
|
self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||||
|
|
||||||
|
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
||||||
|
self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||||
|
|
||||||
|
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
||||||
|
self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||||
|
|
||||||
|
self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
||||||
|
self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||||
|
|
||||||
|
self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
||||||
|
|
||||||
|
self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)
|
||||||
|
|
||||||
|
self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
||||||
|
self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
||||||
|
self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
||||||
|
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
||||||
|
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
||||||
|
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
|
||||||
|
|
||||||
|
def forward(self,x):
|
||||||
|
b, c, h, w = x.shape
|
||||||
|
|
||||||
|
hx = x
|
||||||
|
hxin = self.rebnconvin(hx)
|
||||||
|
|
||||||
|
hx1 = self.rebnconv1(hxin)
|
||||||
|
hx = self.pool1(hx1)
|
||||||
|
|
||||||
|
hx2 = self.rebnconv2(hx)
|
||||||
|
hx = self.pool2(hx2)
|
||||||
|
|
||||||
|
hx3 = self.rebnconv3(hx)
|
||||||
|
hx = self.pool3(hx3)
|
||||||
|
|
||||||
|
hx4 = self.rebnconv4(hx)
|
||||||
|
hx = self.pool4(hx4)
|
||||||
|
|
||||||
|
hx5 = self.rebnconv5(hx)
|
||||||
|
hx = self.pool5(hx5)
|
||||||
|
|
||||||
|
hx6 = self.rebnconv6(hx)
|
||||||
|
|
||||||
|
hx7 = self.rebnconv7(hx6)
|
||||||
|
|
||||||
|
hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))
|
||||||
|
hx6dup = _upsample_like(hx6d,hx5)
|
||||||
|
|
||||||
|
hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1))
|
||||||
|
hx5dup = _upsample_like(hx5d,hx4)
|
||||||
|
|
||||||
|
hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
|
||||||
|
hx4dup = _upsample_like(hx4d,hx3)
|
||||||
|
|
||||||
|
hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
|
||||||
|
hx3dup = _upsample_like(hx3d,hx2)
|
||||||
|
|
||||||
|
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
|
||||||
|
hx2dup = _upsample_like(hx2d,hx1)
|
||||||
|
|
||||||
|
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
|
||||||
|
|
||||||
|
return hx1d + hxin
|
||||||
|
|
||||||
|
|
||||||
|
### RSU-6 ###
|
||||||
|
class RSU6(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
||||||
|
super(RSU6,self).__init__()
|
||||||
|
|
||||||
|
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
|
||||||
|
|
||||||
|
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
|
||||||
|
self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||||
|
|
||||||
|
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
||||||
|
self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||||
|
|
||||||
|
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
||||||
|
self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||||
|
|
||||||
|
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
||||||
|
self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||||
|
|
||||||
|
self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
||||||
|
|
||||||
|
self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)
|
||||||
|
|
||||||
|
self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
||||||
|
self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
||||||
|
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
||||||
|
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
||||||
|
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
|
||||||
|
|
||||||
|
def forward(self,x):
|
||||||
|
|
||||||
|
hx = x
|
||||||
|
|
||||||
|
hxin = self.rebnconvin(hx)
|
||||||
|
|
||||||
|
hx1 = self.rebnconv1(hxin)
|
||||||
|
hx = self.pool1(hx1)
|
||||||
|
|
||||||
|
hx2 = self.rebnconv2(hx)
|
||||||
|
hx = self.pool2(hx2)
|
||||||
|
|
||||||
|
hx3 = self.rebnconv3(hx)
|
||||||
|
hx = self.pool3(hx3)
|
||||||
|
|
||||||
|
hx4 = self.rebnconv4(hx)
|
||||||
|
hx = self.pool4(hx4)
|
||||||
|
|
||||||
|
hx5 = self.rebnconv5(hx)
|
||||||
|
|
||||||
|
hx6 = self.rebnconv6(hx5)
|
||||||
|
|
||||||
|
|
||||||
|
hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1))
|
||||||
|
hx5dup = _upsample_like(hx5d,hx4)
|
||||||
|
|
||||||
|
hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
|
||||||
|
hx4dup = _upsample_like(hx4d,hx3)
|
||||||
|
|
||||||
|
hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
|
||||||
|
hx3dup = _upsample_like(hx3d,hx2)
|
||||||
|
|
||||||
|
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
|
||||||
|
hx2dup = _upsample_like(hx2d,hx1)
|
||||||
|
|
||||||
|
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
|
||||||
|
|
||||||
|
return hx1d + hxin
|
||||||
|
|
||||||
|
### RSU-5 ###
|
||||||
|
class RSU5(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
||||||
|
super(RSU5,self).__init__()
|
||||||
|
|
||||||
|
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
|
||||||
|
|
||||||
|
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
|
||||||
|
self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||||
|
|
||||||
|
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
||||||
|
self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||||
|
|
||||||
|
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
||||||
|
self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||||
|
|
||||||
|
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
||||||
|
|
||||||
|
self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)
|
||||||
|
|
||||||
|
self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
||||||
|
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
||||||
|
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
||||||
|
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
|
||||||
|
|
||||||
|
def forward(self,x):
|
||||||
|
|
||||||
|
hx = x
|
||||||
|
|
||||||
|
hxin = self.rebnconvin(hx)
|
||||||
|
|
||||||
|
hx1 = self.rebnconv1(hxin)
|
||||||
|
hx = self.pool1(hx1)
|
||||||
|
|
||||||
|
hx2 = self.rebnconv2(hx)
|
||||||
|
hx = self.pool2(hx2)
|
||||||
|
|
||||||
|
hx3 = self.rebnconv3(hx)
|
||||||
|
hx = self.pool3(hx3)
|
||||||
|
|
||||||
|
hx4 = self.rebnconv4(hx)
|
||||||
|
|
||||||
|
hx5 = self.rebnconv5(hx4)
|
||||||
|
|
||||||
|
hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))
|
||||||
|
hx4dup = _upsample_like(hx4d,hx3)
|
||||||
|
|
||||||
|
hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
|
||||||
|
hx3dup = _upsample_like(hx3d,hx2)
|
||||||
|
|
||||||
|
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
|
||||||
|
hx2dup = _upsample_like(hx2d,hx1)
|
||||||
|
|
||||||
|
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
|
||||||
|
|
||||||
|
return hx1d + hxin
|
||||||
|
|
||||||
|
### RSU-4 ###
|
||||||
|
class RSU4(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
||||||
|
super(RSU4,self).__init__()
|
||||||
|
|
||||||
|
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
|
||||||
|
|
||||||
|
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
|
||||||
|
self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||||
|
|
||||||
|
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
||||||
|
self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||||
|
|
||||||
|
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
||||||
|
|
||||||
|
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)
|
||||||
|
|
||||||
|
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
||||||
|
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
||||||
|
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
|
||||||
|
|
||||||
|
def forward(self,x):
|
||||||
|
|
||||||
|
hx = x
|
||||||
|
|
||||||
|
hxin = self.rebnconvin(hx)
|
||||||
|
|
||||||
|
hx1 = self.rebnconv1(hxin)
|
||||||
|
hx = self.pool1(hx1)
|
||||||
|
|
||||||
|
hx2 = self.rebnconv2(hx)
|
||||||
|
hx = self.pool2(hx2)
|
||||||
|
|
||||||
|
hx3 = self.rebnconv3(hx)
|
||||||
|
|
||||||
|
hx4 = self.rebnconv4(hx3)
|
||||||
|
|
||||||
|
hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
|
||||||
|
hx3dup = _upsample_like(hx3d,hx2)
|
||||||
|
|
||||||
|
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
|
||||||
|
hx2dup = _upsample_like(hx2d,hx1)
|
||||||
|
|
||||||
|
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
|
||||||
|
|
||||||
|
return hx1d + hxin
|
||||||
|
|
||||||
|
### RSU-4F ###
|
||||||
|
class RSU4F(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
||||||
|
super(RSU4F,self).__init__()
|
||||||
|
|
||||||
|
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
|
||||||
|
|
||||||
|
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
|
||||||
|
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
|
||||||
|
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)
|
||||||
|
|
||||||
|
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)
|
||||||
|
|
||||||
|
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
|
||||||
|
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
|
||||||
|
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
|
||||||
|
|
||||||
|
def forward(self,x):
|
||||||
|
|
||||||
|
hx = x
|
||||||
|
|
||||||
|
hxin = self.rebnconvin(hx)
|
||||||
|
|
||||||
|
hx1 = self.rebnconv1(hxin)
|
||||||
|
hx2 = self.rebnconv2(hx1)
|
||||||
|
hx3 = self.rebnconv3(hx2)
|
||||||
|
|
||||||
|
hx4 = self.rebnconv4(hx3)
|
||||||
|
|
||||||
|
hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
|
||||||
|
hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
|
||||||
|
hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))
|
||||||
|
|
||||||
|
return hx1d + hxin
|
||||||
|
|
||||||
|
|
||||||
|
class myrebnconv(nn.Module):
|
||||||
|
def __init__(self, in_ch=3,
|
||||||
|
out_ch=1,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
dilation=1,
|
||||||
|
groups=1):
|
||||||
|
super(myrebnconv,self).__init__()
|
||||||
|
|
||||||
|
self.conv = nn.Conv2d(in_ch,
|
||||||
|
out_ch,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
dilation=dilation,
|
||||||
|
groups=groups)
|
||||||
|
self.bn = nn.BatchNorm2d(out_ch)
|
||||||
|
self.rl = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
def forward(self,x):
|
||||||
|
return self.rl(self.bn(self.conv(x)))
|
||||||
|
|
||||||
|
|
||||||
|
class ISNetGTEncoder(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,in_ch=1,out_ch=1):
|
||||||
|
super(ISNetGTEncoder,self).__init__()
|
||||||
|
|
||||||
|
self.conv_in = myrebnconv(in_ch,16,3,stride=2,padding=1) # nn.Conv2d(in_ch,64,3,stride=2,padding=1)
|
||||||
|
|
||||||
|
self.stage1 = RSU7(16,16,64)
|
||||||
|
self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||||
|
|
||||||
|
self.stage2 = RSU6(64,16,64)
|
||||||
|
self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||||
|
|
||||||
|
self.stage3 = RSU5(64,32,128)
|
||||||
|
self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||||
|
|
||||||
|
self.stage4 = RSU4(128,32,256)
|
||||||
|
self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||||
|
|
||||||
|
self.stage5 = RSU4F(256,64,512)
|
||||||
|
self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||||
|
|
||||||
|
self.stage6 = RSU4F(512,64,512)
|
||||||
|
|
||||||
|
|
||||||
|
self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
|
||||||
|
self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
|
||||||
|
self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
|
||||||
|
self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
|
||||||
|
self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
|
||||||
|
self.side6 = nn.Conv2d(512,out_ch,3,padding=1)
|
||||||
|
|
||||||
|
def compute_loss_max(self, preds, targets, fs):
|
||||||
|
|
||||||
|
return muti_loss_fusion_max(preds, targets,fs)
|
||||||
|
|
||||||
|
def compute_loss(self, preds, targets):
|
||||||
|
|
||||||
|
return muti_loss_fusion(preds,targets)
|
||||||
|
|
||||||
|
def forward(self,x):
|
||||||
|
|
||||||
|
hx = x
|
||||||
|
|
||||||
|
hxin = self.conv_in(hx)
|
||||||
|
# hx = self.pool_in(hxin)
|
||||||
|
|
||||||
|
#stage 1
|
||||||
|
hx1 = self.stage1(hxin)
|
||||||
|
hx = self.pool12(hx1)
|
||||||
|
|
||||||
|
#stage 2
|
||||||
|
hx2 = self.stage2(hx)
|
||||||
|
hx = self.pool23(hx2)
|
||||||
|
|
||||||
|
#stage 3
|
||||||
|
hx3 = self.stage3(hx)
|
||||||
|
hx = self.pool34(hx3)
|
||||||
|
|
||||||
|
#stage 4
|
||||||
|
hx4 = self.stage4(hx)
|
||||||
|
hx = self.pool45(hx4)
|
||||||
|
|
||||||
|
#stage 5
|
||||||
|
hx5 = self.stage5(hx)
|
||||||
|
hx = self.pool56(hx5)
|
||||||
|
|
||||||
|
#stage 6
|
||||||
|
hx6 = self.stage6(hx)
|
||||||
|
|
||||||
|
|
||||||
|
#side output
|
||||||
|
d1 = self.side1(hx1)
|
||||||
|
d1 = _upsample_like(d1,x)
|
||||||
|
|
||||||
|
d2 = self.side2(hx2)
|
||||||
|
d2 = _upsample_like(d2,x)
|
||||||
|
|
||||||
|
d3 = self.side3(hx3)
|
||||||
|
d3 = _upsample_like(d3,x)
|
||||||
|
|
||||||
|
d4 = self.side4(hx4)
|
||||||
|
d4 = _upsample_like(d4,x)
|
||||||
|
|
||||||
|
d5 = self.side5(hx5)
|
||||||
|
d5 = _upsample_like(d5,x)
|
||||||
|
|
||||||
|
d6 = self.side6(hx6)
|
||||||
|
d6 = _upsample_like(d6,x)
|
||||||
|
|
||||||
|
# d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
|
||||||
|
|
||||||
|
return [F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)], [hx1,hx2,hx3,hx4,hx5,hx6]
|
||||||
|
|
||||||
|
class ISNetDIS(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,in_ch=3,out_ch=1):
|
||||||
|
super(ISNetDIS,self).__init__()
|
||||||
|
|
||||||
|
self.conv_in = nn.Conv2d(in_ch,64,3,stride=2,padding=1)
|
||||||
|
self.pool_in = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||||
|
|
||||||
|
self.stage1 = RSU7(64,32,64)
|
||||||
|
self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||||
|
|
||||||
|
self.stage2 = RSU6(64,32,128)
|
||||||
|
self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||||
|
|
||||||
|
self.stage3 = RSU5(128,64,256)
|
||||||
|
self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||||
|
|
||||||
|
self.stage4 = RSU4(256,128,512)
|
||||||
|
self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||||
|
|
||||||
|
self.stage5 = RSU4F(512,256,512)
|
||||||
|
self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||||
|
|
||||||
|
self.stage6 = RSU4F(512,256,512)
|
||||||
|
|
||||||
|
# decoder
|
||||||
|
self.stage5d = RSU4F(1024,256,512)
|
||||||
|
self.stage4d = RSU4(1024,128,256)
|
||||||
|
self.stage3d = RSU5(512,64,128)
|
||||||
|
self.stage2d = RSU6(256,32,64)
|
||||||
|
self.stage1d = RSU7(128,16,64)
|
||||||
|
|
||||||
|
self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
|
||||||
|
self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
|
||||||
|
self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
|
||||||
|
self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
|
||||||
|
self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
|
||||||
|
self.side6 = nn.Conv2d(512,out_ch,3,padding=1)
|
||||||
|
|
||||||
|
# self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
|
||||||
|
|
||||||
|
def compute_loss_kl(self, preds, targets, dfs, fs, mode='MSE'):
|
||||||
|
|
||||||
|
# return muti_loss_fusion(preds,targets)
|
||||||
|
return muti_loss_fusion_kl(preds, targets, dfs, fs, mode=mode)
|
||||||
|
|
||||||
|
def compute_loss(self, preds, targets):
|
||||||
|
|
||||||
|
# return muti_loss_fusion(preds,targets)
|
||||||
|
return muti_loss_fusion(preds, targets)
|
||||||
|
|
||||||
|
def forward(self,x):
|
||||||
|
|
||||||
|
hx = x
|
||||||
|
|
||||||
|
hxin = self.conv_in(hx)
|
||||||
|
hx = self.pool_in(hxin)
|
||||||
|
|
||||||
|
#stage 1
|
||||||
|
hx1 = self.stage1(hxin)
|
||||||
|
hx = self.pool12(hx1)
|
||||||
|
|
||||||
|
#stage 2
|
||||||
|
hx2 = self.stage2(hx)
|
||||||
|
hx = self.pool23(hx2)
|
||||||
|
|
||||||
|
#stage 3
|
||||||
|
hx3 = self.stage3(hx)
|
||||||
|
hx = self.pool34(hx3)
|
||||||
|
|
||||||
|
#stage 4
|
||||||
|
hx4 = self.stage4(hx)
|
||||||
|
hx = self.pool45(hx4)
|
||||||
|
|
||||||
|
#stage 5
|
||||||
|
hx5 = self.stage5(hx)
|
||||||
|
hx = self.pool56(hx5)
|
||||||
|
|
||||||
|
#stage 6
|
||||||
|
hx6 = self.stage6(hx)
|
||||||
|
hx6up = _upsample_like(hx6,hx5)
|
||||||
|
|
||||||
|
#-------------------- decoder --------------------
|
||||||
|
hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
|
||||||
|
hx5dup = _upsample_like(hx5d,hx4)
|
||||||
|
|
||||||
|
hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
|
||||||
|
hx4dup = _upsample_like(hx4d,hx3)
|
||||||
|
|
||||||
|
hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
|
||||||
|
hx3dup = _upsample_like(hx3d,hx2)
|
||||||
|
|
||||||
|
hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
|
||||||
|
hx2dup = _upsample_like(hx2d,hx1)
|
||||||
|
|
||||||
|
hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
|
||||||
|
|
||||||
|
|
||||||
|
#side output
|
||||||
|
d1 = self.side1(hx1d)
|
||||||
|
d1 = _upsample_like(d1,x)
|
||||||
|
|
||||||
|
d2 = self.side2(hx2d)
|
||||||
|
d2 = _upsample_like(d2,x)
|
||||||
|
|
||||||
|
d3 = self.side3(hx3d)
|
||||||
|
d3 = _upsample_like(d3,x)
|
||||||
|
|
||||||
|
d4 = self.side4(hx4d)
|
||||||
|
d4 = _upsample_like(d4,x)
|
||||||
|
|
||||||
|
d5 = self.side5(hx5d)
|
||||||
|
d5 = _upsample_like(d5,x)
|
||||||
|
|
||||||
|
d6 = self.side6(hx6)
|
||||||
|
d6 = _upsample_like(d6,x)
|
||||||
|
|
||||||
|
# d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
|
||||||
|
|
||||||
|
return [F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)],[hx1d,hx2d,hx3d,hx4d,hx5d,hx6]
|
92
IS-Net/pytorch18.yml
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
name: pytorch18
|
||||||
|
channels:
|
||||||
|
- conda-forge
|
||||||
|
- anaconda
|
||||||
|
- pytorch
|
||||||
|
- defaults
|
||||||
|
dependencies:
|
||||||
|
- _libgcc_mutex=0.1=main
|
||||||
|
- _openmp_mutex=4.5=1_gnu
|
||||||
|
- blas=1.0=mkl
|
||||||
|
- brotli=1.0.9=he6710b0_2
|
||||||
|
- bzip2=1.0.8=h7b6447c_0
|
||||||
|
- ca-certificates=2022.2.1=h06a4308_0
|
||||||
|
- certifi=2021.10.8=py37h06a4308_2
|
||||||
|
- cloudpickle=2.0.0=pyhd3eb1b0_0
|
||||||
|
- colorama=0.4.4=pyhd3eb1b0_0
|
||||||
|
- cudatoolkit=10.2.89=hfd86e86_1
|
||||||
|
- cycler=0.11.0=pyhd3eb1b0_0
|
||||||
|
- cytoolz=0.11.0=py37h7b6447c_0
|
||||||
|
- dask-core=2021.10.0=pyhd3eb1b0_0
|
||||||
|
- ffmpeg=4.3=hf484d3e_0
|
||||||
|
- fonttools=4.25.0=pyhd3eb1b0_0
|
||||||
|
- freetype=2.11.0=h70c0345_0
|
||||||
|
- fsspec=2022.2.0=pyhd3eb1b0_0
|
||||||
|
- gmp=6.2.1=h2531618_2
|
||||||
|
- gnutls=3.6.15=he1e5248_0
|
||||||
|
- imageio=2.9.0=pyhd3eb1b0_0
|
||||||
|
- intel-openmp=2021.4.0=h06a4308_3561
|
||||||
|
- jpeg=9b=h024ee3a_2
|
||||||
|
- kiwisolver=1.3.2=py37h295c915_0
|
||||||
|
- lame=3.100=h7b6447c_0
|
||||||
|
- lcms2=2.12=h3be6417_0
|
||||||
|
- ld_impl_linux-64=2.35.1=h7274673_9
|
||||||
|
- libffi=3.3=he6710b0_2
|
||||||
|
- libgcc-ng=9.3.0=h5101ec6_17
|
||||||
|
- libgfortran-ng=7.5.0=ha8ba4b0_17
|
||||||
|
- libgfortran4=7.5.0=ha8ba4b0_17
|
||||||
|
- libgomp=9.3.0=h5101ec6_17
|
||||||
|
- libiconv=1.15=h63c8f33_5
|
||||||
|
- libidn2=2.3.2=h7f8727e_0
|
||||||
|
- libpng=1.6.37=hbc83047_0
|
||||||
|
- libstdcxx-ng=9.3.0=hd4cf53a_17
|
||||||
|
- libtasn1=4.16.0=h27cfd23_0
|
||||||
|
- libtiff=4.2.0=h85742a9_0
|
||||||
|
- libunistring=0.9.10=h27cfd23_0
|
||||||
|
- libuv=1.40.0=h7b6447c_0
|
||||||
|
- libwebp-base=1.2.2=h7f8727e_0
|
||||||
|
- locket=0.2.1=py37h06a4308_2
|
||||||
|
- lz4-c=1.9.3=h295c915_1
|
||||||
|
- matplotlib-base=3.5.1=py37ha18d171_1
|
||||||
|
- mkl=2021.4.0=h06a4308_640
|
||||||
|
- mkl-service=2.4.0=py37h7f8727e_0
|
||||||
|
- mkl_fft=1.3.1=py37hd3c417c_0
|
||||||
|
- mkl_random=1.2.2=py37h51133e4_0
|
||||||
|
- munkres=1.1.4=py_0
|
||||||
|
- ncurses=6.3=h7f8727e_2
|
||||||
|
- nettle=3.7.3=hbbd107a_1
|
||||||
|
- networkx=2.6.3=pyhd3eb1b0_0
|
||||||
|
- ninja=1.10.2=py37hd09550d_3
|
||||||
|
- numpy=1.21.2=py37h20f2e39_0
|
||||||
|
- numpy-base=1.21.2=py37h79a1101_0
|
||||||
|
- olefile=0.46=py37_0
|
||||||
|
- openh264=2.1.1=h4ff587b_0
|
||||||
|
- openssl=1.1.1n=h7f8727e_0
|
||||||
|
- packaging=21.3=pyhd3eb1b0_0
|
||||||
|
- partd=1.2.0=pyhd3eb1b0_1
|
||||||
|
- pillow=8.0.0=py37h9a89aac_0
|
||||||
|
- pip=21.2.2=py37h06a4308_0
|
||||||
|
- pyparsing=3.0.4=pyhd3eb1b0_0
|
||||||
|
- python=3.7.11=h12debd9_0
|
||||||
|
- python-dateutil=2.8.2=pyhd3eb1b0_0
|
||||||
|
- pytorch=1.8.0=py3.7_cuda10.2_cudnn7.6.5_0
|
||||||
|
- pywavelets=1.1.1=py37h7b6447c_2
|
||||||
|
- pyyaml=6.0=py37h7f8727e_1
|
||||||
|
- readline=8.1.2=h7f8727e_1
|
||||||
|
- scikit-image=0.15.0=py37hb3f55d8_2
|
||||||
|
- scipy=1.7.3=py37hc147768_0
|
||||||
|
- setuptools=58.0.4=py37h06a4308_0
|
||||||
|
- six=1.16.0=pyhd3eb1b0_1
|
||||||
|
- sqlite=3.38.0=hc218d9a_0
|
||||||
|
- tk=8.6.11=h1ccaba5_0
|
||||||
|
- toolz=0.11.2=pyhd3eb1b0_0
|
||||||
|
- torchaudio=0.8.0=py37
|
||||||
|
- torchvision=0.9.0=py37_cu102
|
||||||
|
- tqdm=4.63.0=pyhd8ed1ab_0
|
||||||
|
- typing_extensions=3.10.0.2=pyh06a4308_0
|
||||||
|
- wheel=0.37.1=pyhd3eb1b0_0
|
||||||
|
- xz=5.2.5=h7b6447c_0
|
||||||
|
- yaml=0.2.5=h7b6447c_0
|
||||||
|
- zlib=1.2.11=h7f8727e_4
|
||||||
|
- zstd=1.4.9=haebb681_0
|
||||||
|
prefix: /home/solar/anaconda3/envs/pytorch18
|
87
IS-Net/requirements.txt
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
# This file may be used to create an environment using:
|
||||||
|
# $ conda create --name <env> --file <this file>
|
||||||
|
# platform: linux-64
|
||||||
|
_libgcc_mutex=0.1=main
|
||||||
|
_openmp_mutex=4.5=1_gnu
|
||||||
|
blas=1.0=mkl
|
||||||
|
brotli=1.0.9=he6710b0_2
|
||||||
|
bzip2=1.0.8=h7b6447c_0
|
||||||
|
ca-certificates=2022.2.1=h06a4308_0
|
||||||
|
certifi=2021.10.8=py37h06a4308_2
|
||||||
|
cloudpickle=2.0.0=pyhd3eb1b0_0
|
||||||
|
colorama=0.4.4=pyhd3eb1b0_0
|
||||||
|
cudatoolkit=10.2.89=hfd86e86_1
|
||||||
|
cycler=0.11.0=pyhd3eb1b0_0
|
||||||
|
cytoolz=0.11.0=py37h7b6447c_0
|
||||||
|
dask-core=2021.10.0=pyhd3eb1b0_0
|
||||||
|
ffmpeg=4.3=hf484d3e_0
|
||||||
|
fonttools=4.25.0=pyhd3eb1b0_0
|
||||||
|
freetype=2.11.0=h70c0345_0
|
||||||
|
fsspec=2022.2.0=pyhd3eb1b0_0
|
||||||
|
gmp=6.2.1=h2531618_2
|
||||||
|
gnutls=3.6.15=he1e5248_0
|
||||||
|
imageio=2.9.0=pyhd3eb1b0_0
|
||||||
|
intel-openmp=2021.4.0=h06a4308_3561
|
||||||
|
jpeg=9b=h024ee3a_2
|
||||||
|
kiwisolver=1.3.2=py37h295c915_0
|
||||||
|
lame=3.100=h7b6447c_0
|
||||||
|
lcms2=2.12=h3be6417_0
|
||||||
|
ld_impl_linux-64=2.35.1=h7274673_9
|
||||||
|
libffi=3.3=he6710b0_2
|
||||||
|
libgcc-ng=9.3.0=h5101ec6_17
|
||||||
|
libgfortran-ng=7.5.0=ha8ba4b0_17
|
||||||
|
libgfortran4=7.5.0=ha8ba4b0_17
|
||||||
|
libgomp=9.3.0=h5101ec6_17
|
||||||
|
libiconv=1.15=h63c8f33_5
|
||||||
|
libidn2=2.3.2=h7f8727e_0
|
||||||
|
libpng=1.6.37=hbc83047_0
|
||||||
|
libstdcxx-ng=9.3.0=hd4cf53a_17
|
||||||
|
libtasn1=4.16.0=h27cfd23_0
|
||||||
|
libtiff=4.2.0=h85742a9_0
|
||||||
|
libunistring=0.9.10=h27cfd23_0
|
||||||
|
libuv=1.40.0=h7b6447c_0
|
||||||
|
libwebp-base=1.2.2=h7f8727e_0
|
||||||
|
locket=0.2.1=py37h06a4308_2
|
||||||
|
lz4-c=1.9.3=h295c915_1
|
||||||
|
matplotlib-base=3.5.1=py37ha18d171_1
|
||||||
|
mkl=2021.4.0=h06a4308_640
|
||||||
|
mkl-service=2.4.0=py37h7f8727e_0
|
||||||
|
mkl_fft=1.3.1=py37hd3c417c_0
|
||||||
|
mkl_random=1.2.2=py37h51133e4_0
|
||||||
|
munkres=1.1.4=py_0
|
||||||
|
ncurses=6.3=h7f8727e_2
|
||||||
|
nettle=3.7.3=hbbd107a_1
|
||||||
|
networkx=2.6.3=pyhd3eb1b0_0
|
||||||
|
ninja=1.10.2=py37hd09550d_3
|
||||||
|
numpy=1.21.2=py37h20f2e39_0
|
||||||
|
numpy-base=1.21.2=py37h79a1101_0
|
||||||
|
olefile=0.46=py37_0
|
||||||
|
openh264=2.1.1=h4ff587b_0
|
||||||
|
openssl=1.1.1n=h7f8727e_0
|
||||||
|
packaging=21.3=pyhd3eb1b0_0
|
||||||
|
partd=1.2.0=pyhd3eb1b0_1
|
||||||
|
pillow=8.0.0=py37h9a89aac_0
|
||||||
|
pip=21.2.2=py37h06a4308_0
|
||||||
|
pyparsing=3.0.4=pyhd3eb1b0_0
|
||||||
|
python=3.7.11=h12debd9_0
|
||||||
|
python-dateutil=2.8.2=pyhd3eb1b0_0
|
||||||
|
pytorch=1.8.0=py3.7_cuda10.2_cudnn7.6.5_0
|
||||||
|
pywavelets=1.1.1=py37h7b6447c_2
|
||||||
|
pyyaml=6.0=py37h7f8727e_1
|
||||||
|
readline=8.1.2=h7f8727e_1
|
||||||
|
scikit-image=0.15.0=py37hb3f55d8_2
|
||||||
|
scipy=1.7.3=py37hc147768_0
|
||||||
|
setuptools=58.0.4=py37h06a4308_0
|
||||||
|
six=1.16.0=pyhd3eb1b0_1
|
||||||
|
sqlite=3.38.0=hc218d9a_0
|
||||||
|
tk=8.6.11=h1ccaba5_0
|
||||||
|
toolz=0.11.2=pyhd3eb1b0_0
|
||||||
|
torchaudio=0.8.0=py37
|
||||||
|
torchvision=0.9.0=py37_cu102
|
||||||
|
tqdm=4.63.0=pyhd8ed1ab_0
|
||||||
|
typing_extensions=3.10.0.2=pyh06a4308_0
|
||||||
|
wheel=0.37.1=pyhd3eb1b0_0
|
||||||
|
xz=5.2.5=h7b6447c_0
|
||||||
|
yaml=0.2.5=h7b6447c_0
|
||||||
|
zlib=1.2.11=h7f8727e_4
|
||||||
|
zstd=1.4.9=haebb681_0
|
713
IS-Net/train_valid_inference_main.py
Normal file
@ -0,0 +1,713 @@
|
|||||||
|
import os
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
from skimage import io
|
||||||
|
import time
|
||||||
|
|
||||||
|
import torch, gc
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.autograd import Variable
|
||||||
|
import torch.optim as optim
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from data_loader_cache import get_im_gt_name_dict, create_dataloaders, GOSRandomHFlip, GOSResize, GOSRandomCrop, GOSNormalize #GOSDatasetCache,
|
||||||
|
from basics import f1_mae_torch #normPRED, GOSPRF1ScoresCache,f1score_torch,
|
||||||
|
from models import *
|
||||||
|
|
||||||
|
def get_gt_encoder(train_dataloaders, train_datasets, valid_dataloaders, valid_datasets, hypar, train_dataloaders_val, train_datasets_val): #model_path, model_save_fre, max_ite=1000000):
|
||||||
|
|
||||||
|
# train_dataloaders, train_datasets = create_dataloaders(train_nm_im_gt_list,
|
||||||
|
# cache_size = hypar["cache_size"],
|
||||||
|
# cache_boost = hypar["cache_boost_train"],
|
||||||
|
# my_transforms = [
|
||||||
|
# GOSRandomHFlip(),
|
||||||
|
# # GOSResize(hypar["input_size"]),
|
||||||
|
# # GOSRandomCrop(hypar["crop_size"]),
|
||||||
|
# GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
|
||||||
|
# ],
|
||||||
|
# batch_size = hypar["batch_size_train"],
|
||||||
|
# shuffle = True)
|
||||||
|
|
||||||
|
torch.manual_seed(hypar["seed"])
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.manual_seed(hypar["seed"])
|
||||||
|
|
||||||
|
print("define gt encoder ...")
|
||||||
|
net = ISNetGTEncoder() #UNETGTENCODERCombine()
|
||||||
|
## load the existing model gt encoder
|
||||||
|
if(hypar["gt_encoder_model"]!=""):
|
||||||
|
model_path = hypar["model_path"]+"/"+hypar["gt_encoder_model"]
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
net.load_state_dict(torch.load(model_path))
|
||||||
|
net.cuda()
|
||||||
|
else:
|
||||||
|
net.load_state_dict(torch.load(model_path,map_location="cpu"))
|
||||||
|
print("gt encoder restored from the saved weights ...")
|
||||||
|
return net ############
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
net.cuda()
|
||||||
|
|
||||||
|
print("--- define optimizer for GT Encoder---")
|
||||||
|
optimizer = optim.Adam(net.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
|
||||||
|
|
||||||
|
model_path = hypar["model_path"]
|
||||||
|
model_save_fre = hypar["model_save_fre"]
|
||||||
|
max_ite = hypar["max_ite"]
|
||||||
|
batch_size_train = hypar["batch_size_train"]
|
||||||
|
batch_size_valid = hypar["batch_size_valid"]
|
||||||
|
|
||||||
|
if(not os.path.exists(model_path)):
|
||||||
|
os.mkdir(model_path)
|
||||||
|
|
||||||
|
ite_num = hypar["start_ite"] # count the total iteration number
|
||||||
|
ite_num4val = 0 #
|
||||||
|
running_loss = 0.0 # count the toal loss
|
||||||
|
running_tar_loss = 0.0 # count the target output loss
|
||||||
|
last_f1 = [0 for x in range(len(valid_dataloaders))]
|
||||||
|
|
||||||
|
train_num = train_datasets[0].__len__()
|
||||||
|
|
||||||
|
net.train()
|
||||||
|
|
||||||
|
start_last = time.time()
|
||||||
|
gos_dataloader = train_dataloaders[0]
|
||||||
|
epoch_num = hypar["max_epoch_num"]
|
||||||
|
notgood_cnt = 0
|
||||||
|
for epoch in range(epoch_num): ## set the epoch num as 100000
|
||||||
|
|
||||||
|
for i, data in enumerate(gos_dataloader):
|
||||||
|
|
||||||
|
if(ite_num >= max_ite):
|
||||||
|
print("Training Reached the Maximal Iteration Number ", max_ite)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
# start_read = time.time()
|
||||||
|
ite_num = ite_num + 1
|
||||||
|
ite_num4val = ite_num4val + 1
|
||||||
|
|
||||||
|
# get the inputs
|
||||||
|
labels = data['label']
|
||||||
|
|
||||||
|
if(hypar["model_digit"]=="full"):
|
||||||
|
labels = labels.type(torch.FloatTensor)
|
||||||
|
else:
|
||||||
|
labels = labels.type(torch.HalfTensor)
|
||||||
|
|
||||||
|
# wrap them in Variable
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
labels_v = Variable(labels.cuda(), requires_grad=False)
|
||||||
|
else:
|
||||||
|
labels_v = Variable(labels, requires_grad=False)
|
||||||
|
|
||||||
|
# print("time lapse for data preparation: ", time.time()-start_read, ' s')
|
||||||
|
|
||||||
|
# y zero the parameter gradients
|
||||||
|
start_inf_loss_back = time.time()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
ds, fs = net(labels_v)#net(inputs_v)
|
||||||
|
loss2, loss = net.compute_loss(ds, labels_v)
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
running_loss += loss.item()
|
||||||
|
running_tar_loss += loss2.item()
|
||||||
|
|
||||||
|
# del outputs, loss
|
||||||
|
del ds, loss2, loss
|
||||||
|
end_inf_loss_back = time.time()-start_inf_loss_back
|
||||||
|
|
||||||
|
print("GT Encoder Training>>>"+model_path.split('/')[-1]+" - [epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f, time-per-iter: %3f s, time_read: %3f" % (
|
||||||
|
epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val, time.time()-start_last, time.time()-start_last-end_inf_loss_back))
|
||||||
|
start_last = time.time()
|
||||||
|
|
||||||
|
if ite_num % model_save_fre == 0: # validate every 2000 iterations
|
||||||
|
notgood_cnt += 1
|
||||||
|
# net.eval()
|
||||||
|
# tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time = valid_gt_encoder(net, valid_dataloaders, valid_datasets, hypar, epoch)
|
||||||
|
tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time = valid_gt_encoder(net, train_dataloaders_val, train_datasets_val, hypar, epoch)
|
||||||
|
|
||||||
|
net.train() # resume train
|
||||||
|
|
||||||
|
tmp_out = 0
|
||||||
|
print("last_f1:",last_f1)
|
||||||
|
print("tmp_f1:",tmp_f1)
|
||||||
|
for fi in range(len(last_f1)):
|
||||||
|
if(tmp_f1[fi]>last_f1[fi]):
|
||||||
|
tmp_out = 1
|
||||||
|
print("tmp_out:",tmp_out)
|
||||||
|
if(tmp_out):
|
||||||
|
notgood_cnt = 0
|
||||||
|
last_f1 = tmp_f1
|
||||||
|
tmp_f1_str = [str(round(f1x,4)) for f1x in tmp_f1]
|
||||||
|
tmp_mae_str = [str(round(mx,4)) for mx in tmp_mae]
|
||||||
|
maxf1 = '_'.join(tmp_f1_str)
|
||||||
|
meanM = '_'.join(tmp_mae_str)
|
||||||
|
# .cpu().detach().numpy()
|
||||||
|
model_name = "/GTENCODER-gpu_itr_"+str(ite_num)+\
|
||||||
|
"_traLoss_"+str(np.round(running_loss / ite_num4val,4))+\
|
||||||
|
"_traTarLoss_"+str(np.round(running_tar_loss / ite_num4val,4))+\
|
||||||
|
"_valLoss_"+str(np.round(val_loss /(i_val+1),4))+\
|
||||||
|
"_valTarLoss_"+str(np.round(tar_loss /(i_val+1),4)) + \
|
||||||
|
"_maxF1_" + maxf1 + \
|
||||||
|
"_mae_" + meanM + \
|
||||||
|
"_time_" + str(np.round(np.mean(np.array(tmp_time))/batch_size_valid,6))+".pth"
|
||||||
|
torch.save(net.state_dict(), model_path + model_name)
|
||||||
|
|
||||||
|
running_loss = 0.0
|
||||||
|
running_tar_loss = 0.0
|
||||||
|
ite_num4val = 0
|
||||||
|
|
||||||
|
if(tmp_f1[0]>0.99):
|
||||||
|
print("GT encoder is well-trained and obtained...")
|
||||||
|
return net
|
||||||
|
|
||||||
|
if(notgood_cnt >= hypar["early_stop"]):
|
||||||
|
print("No improvements in the last "+str(notgood_cnt)+" validation periods, so training stopped !")
|
||||||
|
exit()
|
||||||
|
|
||||||
|
print("Training Reaches The Maximum Epoch Number")
|
||||||
|
return net
|
||||||
|
|
||||||
|
def valid_gt_encoder(net, valid_dataloaders, valid_datasets, hypar, epoch=0):
|
||||||
|
net.eval()
|
||||||
|
print("Validating...")
|
||||||
|
epoch_num = hypar["max_epoch_num"]
|
||||||
|
|
||||||
|
val_loss = 0.0
|
||||||
|
tar_loss = 0.0
|
||||||
|
|
||||||
|
|
||||||
|
tmp_f1 = []
|
||||||
|
tmp_mae = []
|
||||||
|
tmp_time = []
|
||||||
|
|
||||||
|
start_valid = time.time()
|
||||||
|
for k in range(len(valid_dataloaders)):
|
||||||
|
|
||||||
|
valid_dataloader = valid_dataloaders[k]
|
||||||
|
valid_dataset = valid_datasets[k]
|
||||||
|
|
||||||
|
val_num = valid_dataset.__len__()
|
||||||
|
mybins = np.arange(0,256)
|
||||||
|
PRE = np.zeros((val_num,len(mybins)-1))
|
||||||
|
REC = np.zeros((val_num,len(mybins)-1))
|
||||||
|
F1 = np.zeros((val_num,len(mybins)-1))
|
||||||
|
MAE = np.zeros((val_num))
|
||||||
|
|
||||||
|
val_cnt = 0.0
|
||||||
|
for i_val, data_val in enumerate(valid_dataloader):
|
||||||
|
|
||||||
|
# imidx_val, inputs_val, labels_val, shapes_val = data_val['imidx'], data_val['image'], data_val['label'], data_val['shape']
|
||||||
|
imidx_val, labels_val, shapes_val = data_val['imidx'], data_val['label'], data_val['shape']
|
||||||
|
|
||||||
|
if(hypar["model_digit"]=="full"):
|
||||||
|
labels_val = labels_val.type(torch.FloatTensor)
|
||||||
|
else:
|
||||||
|
labels_val = labels_val.type(torch.HalfTensor)
|
||||||
|
|
||||||
|
# wrap them in Variable
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
labels_val_v = Variable(labels_val.cuda(), requires_grad=False)
|
||||||
|
else:
|
||||||
|
labels_val_v = Variable(labels_val,requires_grad=False)
|
||||||
|
|
||||||
|
t_start = time.time()
|
||||||
|
ds_val = net(labels_val_v)[0]
|
||||||
|
t_end = time.time()-t_start
|
||||||
|
tmp_time.append(t_end)
|
||||||
|
|
||||||
|
# loss2_val, loss_val = muti_loss_fusion(ds_val, labels_val_v)
|
||||||
|
loss2_val, loss_val = net.compute_loss(ds_val, labels_val_v)
|
||||||
|
|
||||||
|
# compute F measure
|
||||||
|
for t in range(hypar["batch_size_valid"]):
|
||||||
|
val_cnt = val_cnt + 1.0
|
||||||
|
print("num of val: ", val_cnt)
|
||||||
|
i_test = imidx_val[t].data.numpy()
|
||||||
|
|
||||||
|
pred_val = ds_val[0][t,:,:,:] # B x 1 x H x W
|
||||||
|
|
||||||
|
## recover the prediction spatial size to the orignal image size
|
||||||
|
pred_val = torch.squeeze(F.upsample(torch.unsqueeze(pred_val,0),(shapes_val[t][0],shapes_val[t][1]),mode='bilinear'))
|
||||||
|
|
||||||
|
ma = torch.max(pred_val)
|
||||||
|
mi = torch.min(pred_val)
|
||||||
|
pred_val = (pred_val-mi)/(ma-mi) # max = 1
|
||||||
|
# pred_val = normPRED(pred_val)
|
||||||
|
|
||||||
|
gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255
|
||||||
|
with torch.no_grad():
|
||||||
|
gt = torch.tensor(gt).cuda()
|
||||||
|
|
||||||
|
pre,rec,f1,mae = f1_mae_torch(pred_val*255, gt, valid_dataset, i_test, mybins, hypar)
|
||||||
|
|
||||||
|
PRE[i_test,:]=pre
|
||||||
|
REC[i_test,:] = rec
|
||||||
|
F1[i_test,:] = f1
|
||||||
|
MAE[i_test] = mae
|
||||||
|
|
||||||
|
del ds_val, gt
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
# if(loss_val.data[0]>1):
|
||||||
|
val_loss += loss_val.item()#data[0]
|
||||||
|
tar_loss += loss2_val.item()#data[0]
|
||||||
|
|
||||||
|
print("[validating: %5d/%5d] val_ls:%f, tar_ls: %f, f1: %f, mae: %f, time: %f"% (i_val, val_num, val_loss / (i_val + 1), tar_loss / (i_val + 1), np.amax(F1[i_test,:]), MAE[i_test],t_end))
|
||||||
|
|
||||||
|
del loss2_val, loss_val
|
||||||
|
|
||||||
|
print('============================')
|
||||||
|
PRE_m = np.mean(PRE,0)
|
||||||
|
REC_m = np.mean(REC,0)
|
||||||
|
f1_m = (1+0.3)*PRE_m*REC_m/(0.3*PRE_m+REC_m+1e-8)
|
||||||
|
# print('--------------:', np.mean(f1_m))
|
||||||
|
tmp_f1.append(np.amax(f1_m))
|
||||||
|
tmp_mae.append(np.mean(MAE))
|
||||||
|
print("The max F1 Score: %f"%(np.max(f1_m)))
|
||||||
|
print("MAE: ", np.mean(MAE))
|
||||||
|
|
||||||
|
# print('[epoch: %3d/%3d, ite: %5d] tra_ls: %3f, val_ls: %3f, tar_ls: %3f, maxf1: %3f, val_time: %6f'% (epoch + 1, epoch_num, ite_num, running_loss / ite_num4val, val_loss/val_cnt, tar_loss/val_cnt, tmp_f1[-1], time.time()-start_valid))
|
||||||
|
|
||||||
|
return tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time
|
||||||
|
|
||||||
|
def train(net, optimizer, train_dataloaders, train_datasets, valid_dataloaders, valid_datasets, hypar,train_dataloaders_val, train_datasets_val): #model_path, model_save_fre, max_ite=1000000):
|
||||||
|
|
||||||
|
if hypar["interm_sup"]:
|
||||||
|
print("Get the gt encoder ...")
|
||||||
|
featurenet = get_gt_encoder(train_dataloaders, train_datasets, valid_dataloaders, valid_datasets, hypar,train_dataloaders_val, train_datasets_val)
|
||||||
|
## freeze the weights of gt encoder
|
||||||
|
for param in featurenet.parameters():
|
||||||
|
param.requires_grad=False
|
||||||
|
|
||||||
|
|
||||||
|
model_path = hypar["model_path"]
|
||||||
|
model_save_fre = hypar["model_save_fre"]
|
||||||
|
max_ite = hypar["max_ite"]
|
||||||
|
batch_size_train = hypar["batch_size_train"]
|
||||||
|
batch_size_valid = hypar["batch_size_valid"]
|
||||||
|
|
||||||
|
if(not os.path.exists(model_path)):
|
||||||
|
os.mkdir(model_path)
|
||||||
|
|
||||||
|
ite_num = hypar["start_ite"] # count the toal iteration number
|
||||||
|
ite_num4val = 0 #
|
||||||
|
running_loss = 0.0 # count the toal loss
|
||||||
|
running_tar_loss = 0.0 # count the target output loss
|
||||||
|
last_f1 = [0 for x in range(len(valid_dataloaders))]
|
||||||
|
|
||||||
|
train_num = train_datasets[0].__len__()
|
||||||
|
|
||||||
|
net.train()
|
||||||
|
|
||||||
|
start_last = time.time()
|
||||||
|
gos_dataloader = train_dataloaders[0]
|
||||||
|
epoch_num = hypar["max_epoch_num"]
|
||||||
|
notgood_cnt = 0
|
||||||
|
for epoch in range(epoch_num): ## set the epoch num as 100000
|
||||||
|
|
||||||
|
for i, data in enumerate(gos_dataloader):
|
||||||
|
|
||||||
|
if(ite_num >= max_ite):
|
||||||
|
print("Training Reached the Maximal Iteration Number ", max_ite)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
# start_read = time.time()
|
||||||
|
ite_num = ite_num + 1
|
||||||
|
ite_num4val = ite_num4val + 1
|
||||||
|
|
||||||
|
# get the inputs
|
||||||
|
inputs, labels = data['image'], data['label']
|
||||||
|
|
||||||
|
if(hypar["model_digit"]=="full"):
|
||||||
|
inputs = inputs.type(torch.FloatTensor)
|
||||||
|
labels = labels.type(torch.FloatTensor)
|
||||||
|
else:
|
||||||
|
inputs = inputs.type(torch.HalfTensor)
|
||||||
|
labels = labels.type(torch.HalfTensor)
|
||||||
|
|
||||||
|
# wrap them in Variable
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
inputs_v, labels_v = Variable(inputs.cuda(), requires_grad=False), Variable(labels.cuda(), requires_grad=False)
|
||||||
|
else:
|
||||||
|
inputs_v, labels_v = Variable(inputs, requires_grad=False), Variable(labels, requires_grad=False)
|
||||||
|
|
||||||
|
# print("time lapse for data preparation: ", time.time()-start_read, ' s')
|
||||||
|
|
||||||
|
# y zero the parameter gradients
|
||||||
|
start_inf_loss_back = time.time()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
if hypar["interm_sup"]:
|
||||||
|
# forward + backward + optimize
|
||||||
|
ds,dfs = net(inputs_v)
|
||||||
|
_,fs = featurenet(labels_v) ## extract the gt encodings
|
||||||
|
loss2, loss = net.compute_loss_kl(ds, labels_v, dfs, fs, mode='MSE')
|
||||||
|
else:
|
||||||
|
# forward + backward + optimize
|
||||||
|
ds,_ = net(inputs_v)
|
||||||
|
loss2, loss = muti_loss_fusion(ds, labels_v)
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
# # print statistics
|
||||||
|
running_loss += loss.item()
|
||||||
|
running_tar_loss += loss2.item()
|
||||||
|
|
||||||
|
# del outputs, loss
|
||||||
|
del ds, loss2, loss
|
||||||
|
end_inf_loss_back = time.time()-start_inf_loss_back
|
||||||
|
|
||||||
|
print(">>>"+model_path.split('/')[-1]+" - [epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f, time-per-iter: %3f s, time_read: %3f" % (
|
||||||
|
epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val, time.time()-start_last, time.time()-start_last-end_inf_loss_back))
|
||||||
|
start_last = time.time()
|
||||||
|
|
||||||
|
if ite_num % model_save_fre == 0: # validate every 2000 iterations
|
||||||
|
notgood_cnt += 1
|
||||||
|
net.eval()
|
||||||
|
tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time = valid(net, valid_dataloaders, valid_datasets, hypar, epoch)
|
||||||
|
net.train() # resume train
|
||||||
|
|
||||||
|
tmp_out = 0
|
||||||
|
print("last_f1:",last_f1)
|
||||||
|
print("tmp_f1:",tmp_f1)
|
||||||
|
for fi in range(len(last_f1)):
|
||||||
|
if(tmp_f1[fi]>last_f1[fi]):
|
||||||
|
tmp_out = 1
|
||||||
|
print("tmp_out:",tmp_out)
|
||||||
|
if(tmp_out):
|
||||||
|
notgood_cnt = 0
|
||||||
|
last_f1 = tmp_f1
|
||||||
|
tmp_f1_str = [str(round(f1x,4)) for f1x in tmp_f1]
|
||||||
|
tmp_mae_str = [str(round(mx,4)) for mx in tmp_mae]
|
||||||
|
maxf1 = '_'.join(tmp_f1_str)
|
||||||
|
meanM = '_'.join(tmp_mae_str)
|
||||||
|
# .cpu().detach().numpy()
|
||||||
|
model_name = "/gpu_itr_"+str(ite_num)+\
|
||||||
|
"_traLoss_"+str(np.round(running_loss / ite_num4val,4))+\
|
||||||
|
"_traTarLoss_"+str(np.round(running_tar_loss / ite_num4val,4))+\
|
||||||
|
"_valLoss_"+str(np.round(val_loss /(i_val+1),4))+\
|
||||||
|
"_valTarLoss_"+str(np.round(tar_loss /(i_val+1),4)) + \
|
||||||
|
"_maxF1_" + maxf1 + \
|
||||||
|
"_mae_" + meanM + \
|
||||||
|
"_time_" + str(np.round(np.mean(np.array(tmp_time))/batch_size_valid,6))+".pth"
|
||||||
|
torch.save(net.state_dict(), model_path + model_name)
|
||||||
|
|
||||||
|
running_loss = 0.0
|
||||||
|
running_tar_loss = 0.0
|
||||||
|
ite_num4val = 0
|
||||||
|
|
||||||
|
if(notgood_cnt >= hypar["early_stop"]):
|
||||||
|
print("No improvements in the last "+str(notgood_cnt)+" validation periods, so training stopped !")
|
||||||
|
exit()
|
||||||
|
|
||||||
|
print("Training Reaches The Maximum Epoch Number")
|
||||||
|
|
||||||
|
def valid(net, valid_dataloaders, valid_datasets, hypar, epoch=0):
|
||||||
|
net.eval()
|
||||||
|
print("Validating...")
|
||||||
|
epoch_num = hypar["max_epoch_num"]
|
||||||
|
|
||||||
|
val_loss = 0.0
|
||||||
|
tar_loss = 0.0
|
||||||
|
val_cnt = 0.0
|
||||||
|
|
||||||
|
tmp_f1 = []
|
||||||
|
tmp_mae = []
|
||||||
|
tmp_time = []
|
||||||
|
|
||||||
|
start_valid = time.time()
|
||||||
|
|
||||||
|
for k in range(len(valid_dataloaders)):
|
||||||
|
|
||||||
|
valid_dataloader = valid_dataloaders[k]
|
||||||
|
valid_dataset = valid_datasets[k]
|
||||||
|
|
||||||
|
val_num = valid_dataset.__len__()
|
||||||
|
mybins = np.arange(0,256)
|
||||||
|
PRE = np.zeros((val_num,len(mybins)-1))
|
||||||
|
REC = np.zeros((val_num,len(mybins)-1))
|
||||||
|
F1 = np.zeros((val_num,len(mybins)-1))
|
||||||
|
MAE = np.zeros((val_num))
|
||||||
|
|
||||||
|
for i_val, data_val in enumerate(valid_dataloader):
|
||||||
|
val_cnt = val_cnt + 1.0
|
||||||
|
imidx_val, inputs_val, labels_val, shapes_val = data_val['imidx'], data_val['image'], data_val['label'], data_val['shape']
|
||||||
|
|
||||||
|
if(hypar["model_digit"]=="full"):
|
||||||
|
inputs_val = inputs_val.type(torch.FloatTensor)
|
||||||
|
labels_val = labels_val.type(torch.FloatTensor)
|
||||||
|
else:
|
||||||
|
inputs_val = inputs_val.type(torch.HalfTensor)
|
||||||
|
labels_val = labels_val.type(torch.HalfTensor)
|
||||||
|
|
||||||
|
# wrap them in Variable
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
inputs_val_v, labels_val_v = Variable(inputs_val.cuda(), requires_grad=False), Variable(labels_val.cuda(), requires_grad=False)
|
||||||
|
else:
|
||||||
|
inputs_val_v, labels_val_v = Variable(inputs_val, requires_grad=False), Variable(labels_val,requires_grad=False)
|
||||||
|
|
||||||
|
t_start = time.time()
|
||||||
|
ds_val = net(inputs_val_v)[0]
|
||||||
|
t_end = time.time()-t_start
|
||||||
|
tmp_time.append(t_end)
|
||||||
|
|
||||||
|
# loss2_val, loss_val = muti_loss_fusion(ds_val, labels_val_v)
|
||||||
|
loss2_val, loss_val = net.compute_loss(ds_val, labels_val_v)
|
||||||
|
|
||||||
|
# compute F measure
|
||||||
|
for t in range(hypar["batch_size_valid"]):
|
||||||
|
i_test = imidx_val[t].data.numpy()
|
||||||
|
|
||||||
|
pred_val = ds_val[0][t,:,:,:] # B x 1 x H x W
|
||||||
|
|
||||||
|
## recover the prediction spatial size to the orignal image size
|
||||||
|
pred_val = torch.squeeze(F.upsample(torch.unsqueeze(pred_val,0),(shapes_val[t][0],shapes_val[t][1]),mode='bilinear'))
|
||||||
|
|
||||||
|
# pred_val = normPRED(pred_val)
|
||||||
|
ma = torch.max(pred_val)
|
||||||
|
mi = torch.min(pred_val)
|
||||||
|
pred_val = (pred_val-mi)/(ma-mi) # max = 1
|
||||||
|
|
||||||
|
gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255
|
||||||
|
with torch.no_grad():
|
||||||
|
gt = torch.tensor(gt).cuda()
|
||||||
|
|
||||||
|
pre,rec,f1,mae = f1_mae_torch(pred_val*255, gt, valid_dataset, i_test, mybins, hypar)
|
||||||
|
|
||||||
|
|
||||||
|
PRE[i_test,:]=pre
|
||||||
|
REC[i_test,:] = rec
|
||||||
|
F1[i_test,:] = f1
|
||||||
|
MAE[i_test] = mae
|
||||||
|
|
||||||
|
del ds_val, gt
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
# if(loss_val.data[0]>1):
|
||||||
|
val_loss += loss_val.item()#data[0]
|
||||||
|
tar_loss += loss2_val.item()#data[0]
|
||||||
|
|
||||||
|
print("[validating: %5d/%5d] val_ls:%f, tar_ls: %f, f1: %f, mae: %f, time: %f"% (i_val, val_num, val_loss / (i_val + 1), tar_loss / (i_val + 1), np.amax(F1[i_test,:]), MAE[i_test],t_end))
|
||||||
|
|
||||||
|
del loss2_val, loss_val
|
||||||
|
|
||||||
|
print('============================')
|
||||||
|
PRE_m = np.mean(PRE,0)
|
||||||
|
REC_m = np.mean(REC,0)
|
||||||
|
f1_m = (1+0.3)*PRE_m*REC_m/(0.3*PRE_m+REC_m+1e-8)
|
||||||
|
|
||||||
|
tmp_f1.append(np.amax(f1_m))
|
||||||
|
tmp_mae.append(np.mean(MAE))
|
||||||
|
|
||||||
|
return tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time
|
||||||
|
|
||||||
|
def main(train_datasets,
|
||||||
|
valid_datasets,
|
||||||
|
hypar): # model: "train", "test"
|
||||||
|
|
||||||
|
### --- Step 1: Build datasets and dataloaders ---
|
||||||
|
dataloaders_train = []
|
||||||
|
dataloaders_valid = []
|
||||||
|
|
||||||
|
if(hypar["mode"]=="train"):
|
||||||
|
print("--- create training dataloader ---")
|
||||||
|
## collect training dataset
|
||||||
|
train_nm_im_gt_list = get_im_gt_name_dict(train_datasets, flag="train")
|
||||||
|
## build dataloader for training datasets
|
||||||
|
train_dataloaders, train_datasets = create_dataloaders(train_nm_im_gt_list,
|
||||||
|
cache_size = hypar["cache_size"],
|
||||||
|
cache_boost = hypar["cache_boost_train"],
|
||||||
|
my_transforms = [
|
||||||
|
GOSRandomHFlip(), ## this line can be uncommented for horizontal flip augmetation
|
||||||
|
# GOSResize(hypar["input_size"]),
|
||||||
|
# GOSRandomCrop(hypar["crop_size"]), ## this line can be uncommented for randomcrop augmentation
|
||||||
|
GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
|
||||||
|
],
|
||||||
|
batch_size = hypar["batch_size_train"],
|
||||||
|
shuffle = True)
|
||||||
|
train_dataloaders_val, train_datasets_val = create_dataloaders(train_nm_im_gt_list,
|
||||||
|
cache_size = hypar["cache_size"],
|
||||||
|
cache_boost = hypar["cache_boost_train"],
|
||||||
|
my_transforms = [
|
||||||
|
GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
|
||||||
|
],
|
||||||
|
batch_size = hypar["batch_size_valid"],
|
||||||
|
shuffle = False)
|
||||||
|
print(len(train_dataloaders), " train dataloaders created")
|
||||||
|
|
||||||
|
print("--- create valid dataloader ---")
|
||||||
|
## build dataloader for validation or testing
|
||||||
|
valid_nm_im_gt_list = get_im_gt_name_dict(valid_datasets, flag="valid")
|
||||||
|
## build dataloader for training datasets
|
||||||
|
valid_dataloaders, valid_datasets = create_dataloaders(valid_nm_im_gt_list,
|
||||||
|
cache_size = hypar["cache_size"],
|
||||||
|
cache_boost = hypar["cache_boost_valid"],
|
||||||
|
my_transforms = [
|
||||||
|
GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
|
||||||
|
# GOSResize(hypar["input_size"])
|
||||||
|
],
|
||||||
|
batch_size=hypar["batch_size_valid"],
|
||||||
|
shuffle=False)
|
||||||
|
print(len(valid_dataloaders), " valid dataloaders created")
|
||||||
|
# print(valid_datasets[0]["data_name"])
|
||||||
|
|
||||||
|
### --- Step 2: Build Model and Optimizer ---
|
||||||
|
print("--- build model ---")
|
||||||
|
net = hypar["model"]#GOSNETINC(3,1)
|
||||||
|
|
||||||
|
# convert to half precision
|
||||||
|
if(hypar["model_digit"]=="half"):
|
||||||
|
net.half()
|
||||||
|
for layer in net.modules():
|
||||||
|
if isinstance(layer, nn.BatchNorm2d):
|
||||||
|
layer.float()
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
net.cuda()
|
||||||
|
|
||||||
|
if(hypar["restore_model"]!=""):
|
||||||
|
print("restore model from:")
|
||||||
|
print(hypar["model_path"]+"/"+hypar["restore_model"])
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
net.load_state_dict(torch.load(hypar["model_path"]+"/"+hypar["restore_model"]))
|
||||||
|
else:
|
||||||
|
net.load_state_dict(torch.load(hypar["model_path"]+"/"+hypar["restore_model"],map_location="cpu"))
|
||||||
|
|
||||||
|
print("--- define optimizer ---")
|
||||||
|
optimizer = optim.Adam(net.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
|
||||||
|
|
||||||
|
### --- Step 3: Train or Valid Model ---
|
||||||
|
if(hypar["mode"]=="train"):
|
||||||
|
train(net,
|
||||||
|
optimizer,
|
||||||
|
train_dataloaders,
|
||||||
|
train_datasets,
|
||||||
|
valid_dataloaders,
|
||||||
|
valid_datasets,
|
||||||
|
hypar,
|
||||||
|
train_dataloaders_val, train_datasets_val)
|
||||||
|
else:
|
||||||
|
valid(net,
|
||||||
|
valid_dataloaders,
|
||||||
|
valid_datasets,
|
||||||
|
hypar)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
### --------------- STEP 1: Configuring the Train, Valid and Test datasets ---------------
|
||||||
|
## configure the train, valid and inference datasets
|
||||||
|
train_datasets, valid_datasets = [], []
|
||||||
|
dataset_1, dataset_1 = {}, {}
|
||||||
|
|
||||||
|
dataset_tr = {"name": "DIS5K-TR",
|
||||||
|
"im_dir": "../DIS5K/DIS-TR/im",
|
||||||
|
"gt_dir": "../DIS5K/DIS-TR/gt",
|
||||||
|
"im_ext": ".jpg",
|
||||||
|
"gt_ext": ".png",
|
||||||
|
"cache_dir":"../DIS5K-Cache/DIS-TR"}
|
||||||
|
|
||||||
|
dataset_vd = {"name": "DIS5K-VD",
|
||||||
|
"im_dir": "../DIS5K/DIS-VD/im",
|
||||||
|
"gt_dir": "../DIS5K/DIS-VD/gt",
|
||||||
|
"im_ext": ".jpg",
|
||||||
|
"gt_ext": ".png",
|
||||||
|
"cache_dir":"../DIS5K-Cache/DIS-VD"}
|
||||||
|
|
||||||
|
dataset_te1 = {"name": "DIS5K-TE1",
|
||||||
|
"im_dir": "../DIS5K/DIS-TE1/im",
|
||||||
|
"gt_dir": "../DIS5K/DIS-TE1/gt",
|
||||||
|
"im_ext": ".jpg",
|
||||||
|
"gt_ext": ".png",
|
||||||
|
"cache_dir":"../DIS5K-Cache/DIS-TE1"}
|
||||||
|
|
||||||
|
dataset_te2 = {"name": "DIS5K-TE2",
|
||||||
|
"im_dir": "../DIS5K/DIS-TE2/im",
|
||||||
|
"gt_dir": "../DIS5K/DIS-TE2/gt",
|
||||||
|
"im_ext": ".jpg",
|
||||||
|
"gt_ext": ".png",
|
||||||
|
"cache_dir":"../DIS5K-Cache/DIS-TE2"}
|
||||||
|
|
||||||
|
dataset_te3 = {"name": "DIS5K-TE3",
|
||||||
|
"im_dir": "../DIS5K/DIS-TE3/im",
|
||||||
|
"gt_dir": "../DIS5K/DIS-TE3/gt",
|
||||||
|
"im_ext": ".jpg",
|
||||||
|
"gt_ext": ".png",
|
||||||
|
"cache_dir":"../DIS5K-Cache/DIS-TE3"}
|
||||||
|
|
||||||
|
dataset_te4 = {"name": "DIS5K-TE4",
|
||||||
|
"im_dir": "../DIS5K/DIS-TE4/im",
|
||||||
|
"gt_dir": "../DIS5K/DIS-TE4/gt",
|
||||||
|
"im_ext": ".jpg",
|
||||||
|
"gt_ext": ".png",
|
||||||
|
"cache_dir":"../DIS5K-Cache/DIS-TE4"}
|
||||||
|
|
||||||
|
train_datasets = [dataset_tr] ## users can create mutiple dictionary for setting a list of datasets as training set
|
||||||
|
# valid_datasets = [dataset_vd] ## users can create mutiple dictionary for setting a list of datasets as vaidation sets or inference sets
|
||||||
|
valid_datasets = [dataset_vd] #, dataset_te1, dataset_te2, dataset_te3, dataset_te4] # and hypar["mode"] = "valid" for inference,
|
||||||
|
|
||||||
|
### --------------- STEP 2: Configuring the hyperparamters for Training, validation and inferencing ---------------
|
||||||
|
hypar = {}
|
||||||
|
|
||||||
|
## -- 2.1. configure the model saving or restoring path --
|
||||||
|
hypar["mode"] = "train"
|
||||||
|
## "train": for training,
|
||||||
|
## "valid": for validation and inferening,
|
||||||
|
## in "valid" mode, it will calculate the accuracy as well as save the prediciton results into the "hypar["valid_out_dir"]", which shouldn't be ""
|
||||||
|
## otherwise only accuracy will be calculated and no predictions will be saved
|
||||||
|
hypar["interm_sup"] = False ## indicate if activate intermediate feature supervision
|
||||||
|
|
||||||
|
if hypar["mode"] == "train":
|
||||||
|
hypar["valid_out_dir"] = "" ## for "train" model leave it as "", for "valid"("inference") mode: set it according to your local directory
|
||||||
|
hypar["model_path"] ="../saved_models/IS-Net-test" ## model weights saving (or restoring) path
|
||||||
|
hypar["restore_model"] = "" ## name of the segmentation model weights .pth for resume training process from last stop or for the inferencing
|
||||||
|
hypar["start_ite"] = 0 ## start iteration for the training, can be changed to match the restored training process
|
||||||
|
hypar["gt_encoder_model"] = ""
|
||||||
|
else: ## configure the segmentation output path and the to-be-used model weights path
|
||||||
|
hypar["valid_out_dir"] = "../DIS5K-Results-test" ## output inferenced segmentation maps into this fold
|
||||||
|
hypar["model_path"] ="../saved_models/IS-Net" ## load trained weights from this path
|
||||||
|
hypar["restore_model"] = "isnet.pth" ## name of the to-be-loaded weights
|
||||||
|
|
||||||
|
# if hypar["restore_model"]!="":
|
||||||
|
# hypar["start_ite"] = int(hypar["restore_model"].split("_")[2])
|
||||||
|
|
||||||
|
## -- 2.2. choose floating point accuracy --
|
||||||
|
hypar["model_digit"] = "full" ## indicates "half" or "full" accuracy of float number
|
||||||
|
hypar["seed"] = 0
|
||||||
|
|
||||||
|
## -- 2.3. cache data spatial size --
|
||||||
|
## To handle large size input images, which take a lot of time for loading in training,
|
||||||
|
# we introduce the cache mechanism for pre-convering and resizing the jpg and png images into .pt file
|
||||||
|
hypar["cache_size"] = [1024, 1024] ## cached input spatial resolution, can be configured into different size
|
||||||
|
hypar["cache_boost_train"] = False ## "True" or "False", indicates wheather to load all the training datasets into RAM, True will greatly speed the training process while requires more RAM
|
||||||
|
hypar["cache_boost_valid"] = False ## "True" or "False", indicates wheather to load all the validation datasets into RAM, True will greatly speed the training process while requires more RAM
|
||||||
|
|
||||||
|
## --- 2.4. data augmentation parameters ---
|
||||||
|
hypar["input_size"] = [1024, 1024] ## mdoel input spatial size, usually use the same value hypar["cache_size"], which means we don't further resize the images
|
||||||
|
hypar["crop_size"] = [1024, 1024] ## random crop size from the input, it is usually set as smaller than hypar["cache_size"], e.g., [920,920] for data augmentation
|
||||||
|
hypar["random_flip_h"] = 1 ## horizontal flip, currently hard coded in the dataloader and it is not in use
|
||||||
|
hypar["random_flip_v"] = 0 ## vertical flip , currently not in use
|
||||||
|
|
||||||
|
## --- 2.5. define model ---
|
||||||
|
print("building model...")
|
||||||
|
hypar["model"] = ISNetDIS() #U2NETFASTFEATURESUP()
|
||||||
|
hypar["early_stop"] = 20 ## stop the training when no improvement in the past 20 validation periods, smaller numbers can be used here e.g., 5 or 10.
|
||||||
|
hypar["model_save_fre"] = 2000 ## valid and save model weights every 2000 iterations
|
||||||
|
|
||||||
|
hypar["batch_size_train"] = 8 ## batch size for training
|
||||||
|
hypar["batch_size_valid"] = 1 ## batch size for validation and inferencing
|
||||||
|
print("batch size: ", hypar["batch_size_train"])
|
||||||
|
|
||||||
|
hypar["max_ite"] = 10000000 ## if early stop couldn't stop the training process, stop it by the max_ite_num
|
||||||
|
hypar["max_epoch_num"] = 1000000 ## if early stop and max_ite couldn't stop the training process, stop it by the max_epoch_num
|
||||||
|
|
||||||
|
main(train_datasets,
|
||||||
|
valid_datasets,
|
||||||
|
hypar=hypar)
|
161
README.md
@ -2,27 +2,147 @@
|
|||||||
<img width="420" height="320" src="figures/dis-logo-official.png">
|
<img width="420" height="320" src="figures/dis-logo-official.png">
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
|
![dis5k-v1-sailship](figures/dis5k-v1-sailship.jpeg)
|
||||||
|
|
||||||
|
<br>
|
||||||
|
|
||||||
## [Highly Accurate Dichotomous Image Segmentation (ECCV 2022)](https://arxiv.org/pdf/2203.03041.pdf)
|
## [Highly Accurate Dichotomous Image Segmentation (ECCV 2022)](https://arxiv.org/pdf/2203.03041.pdf)
|
||||||
[Xuebin Qin](https://xuebinqin.github.io/), [Hang Dai](https://scholar.google.co.uk/citations?user=6yvjpQQAAAAJ&hl=en), [Xiaobin Hu](https://scholar.google.de/citations?user=3lMuodUAAAAJ&hl=en), [Deng-Ping Fan*](https://dengpingfan.github.io/), [Ling Shao](https://scholar.google.com/citations?user=z84rLjoAAAAJ&hl=en) and [Luc Van Gool](https://scholar.google.com/citations?user=TwMib_QAAAAJ&hl=en).
|
[Xuebin Qin](https://xuebinqin.github.io/), [Hang Dai](https://scholar.google.co.uk/citations?user=6yvjpQQAAAAJ&hl=en), [Xiaobin Hu](https://scholar.google.de/citations?user=3lMuodUAAAAJ&hl=en), [Deng-Ping Fan*](https://dengpingfan.github.io/), [Ling Shao](https://scholar.google.com/citations?user=z84rLjoAAAAJ&hl=en), [Luc Van Gool](https://scholar.google.com/citations?user=TwMib_QAAAAJ&hl=en).
|
||||||
|
|
||||||
|
<br>
|
||||||
|
|
||||||
This is the official repo for our new project DIS:
|
## This is the official repo for our newly formulated DIS task:
|
||||||
[**[Project Page]**](https://xuebinqin.github.io/dis/index.html)
|
[**Project Page**](https://xuebinqin.github.io/dis/index.html), [**Arxiv**](https://arxiv.org/pdf/2203.03041.pdf).
|
||||||
|
|
||||||
|
<br>
|
||||||
|
|
||||||
## Updates !!!
|
## Updates !!!
|
||||||
|
|
||||||
|
<br>
|
||||||
|
|
||||||
|
## ** (2022-Jul.-17)** Our paper, code and dataset are now officially released!!! Please check our project page for more details: [**Project Page**](https://xuebinqin.github.io/dis/index.html).<br>
|
||||||
** (2022-Jul.-5)** Our DIS work is now accepted by ECCV 2022, the code and dataset will be released before July 17th, 2022. Please be aware of our updates.
|
** (2022-Jul.-5)** Our DIS work is now accepted by ECCV 2022, the code and dataset will be released before July 17th, 2022. Please be aware of our updates.
|
||||||
|
|
||||||
## DEMO
|
<br>
|
||||||
![ship-demo](figures/ship-demo.gif)
|
|
||||||
![bg-removal](figures/bg-removal.gif)
|
|
||||||
![view-move](figures/view-move.gif)
|
|
||||||
![motor-demo](figures/motor-demo.gif)
|
|
||||||
|
|
||||||
## Comparisons Against SOTAs
|
## 1. [Our DIS5K Dataset V1.0 (Version Alias: DIS5K Sailship)](https://xuebinqin.github.io/dis/index.html)
|
||||||
|
|
||||||
|
<br>
|
||||||
|
|
||||||
|
### Download: [Google Drive](https://drive.google.com/file/d/1jOC2zK0GowBvEt03B7dGugCRDVAoeIqq/view?usp=sharing) or [Baidu Pan 提取码:rtgw](https://pan.baidu.com/s/1y6CQJYledfYyEO0C_Gejpw?pwd=rtgw)
|
||||||
|
|
||||||
|
![dis5k-dataset-v1-sailship](figures/DIS5k-dataset-v1-sailship.png)
|
||||||
|
![complexities-qual](figures/complexities-qual.jpeg)
|
||||||
|
![categories](figures/categories.jpeg)
|
||||||
|
|
||||||
|
<br>
|
||||||
|
|
||||||
|
## 2. APPLICATIONS of Our DIS5K Dataset
|
||||||
|
|
||||||
|
<br>
|
||||||
|
|
||||||
|
### 3D Modeling
|
||||||
|
![3d-modeling](figures/3d-modeling.png)
|
||||||
|
|
||||||
|
### Image Editing
|
||||||
|
![ship-demo](figures/ship-demo.gif)
|
||||||
|
### Art Design Materials
|
||||||
|
![bg-removal](figures/bg-removal.gif)
|
||||||
|
### Still Image Animation
|
||||||
|
![view-move](figures/view-move.gif)
|
||||||
|
### AR
|
||||||
|
![motor-demo](figures/motor-demo.gif)
|
||||||
|
### 3D Rendering
|
||||||
|
![video-3d](figures/video-3d.gif)
|
||||||
|
|
||||||
|
<br>
|
||||||
|
|
||||||
|
## 3. Architecture of Our IS-Net
|
||||||
|
|
||||||
|
<br>
|
||||||
|
|
||||||
|
![is-net](figures/is-net.png)
|
||||||
|
|
||||||
|
<br>
|
||||||
|
|
||||||
|
## 4. Human Correction Efforts (HCE)
|
||||||
|
|
||||||
|
<br>
|
||||||
|
|
||||||
|
![hce-metric](figures/hce-metric.png)
|
||||||
|
|
||||||
|
<br>
|
||||||
|
|
||||||
|
## 5. Experimental Results
|
||||||
|
|
||||||
|
<br>
|
||||||
|
|
||||||
|
### Predicted Maps, [(Google Drive)](https://drive.google.com/file/d/1FMtDLFrL6xVc41eKlLnuZWMBAErnKv0Y/view?usp=sharing), [(Baidu Pan 提取码:ph1d)](https://pan.baidu.com/s/1WUk2RYYpii2xzrvLna9Fsg?pwd=ph1d), of Our IS-Net and Other SOTAs
|
||||||
|
|
||||||
|
### Qualitative Comparisons Against SOTAs
|
||||||
![qual-comp](figures/qual-comp.jpg)
|
![qual-comp](figures/qual-comp.jpg)
|
||||||
|
|
||||||
|
### Quantitative Comparisons Against SOTAs
|
||||||
|
![qual-comp](figures/quan-comp.png)
|
||||||
|
|
||||||
|
<br>
|
||||||
|
|
||||||
|
## 6. Run Our Code
|
||||||
|
|
||||||
|
<br>
|
||||||
|
|
||||||
|
### (1) Clone this repo
|
||||||
|
```
|
||||||
|
git clone https://github.com/xuebinqin/DIS.git
|
||||||
|
```
|
||||||
|
|
||||||
|
### (2) Configuring the environment: go to the root ```DIS``` folder and run
|
||||||
|
```
|
||||||
|
conda env create -f pytorch18.yml
|
||||||
|
```
|
||||||
|
Or you can check the ```requirements.txt``` to configure the dependancies.
|
||||||
|
|
||||||
|
### (3) Train:
|
||||||
|
#### (a) Open ```train_valid_inference_main.py```, set the path of your to-be-inferenced ```train_datasets``` and ```valid_datasets```, e.g., ```valid_datasets=[dataset_vd]```
|
||||||
|
#### (b) Set the ```hypar["mode"]``` to ```"train"```
|
||||||
|
#### (c) Create a new folder ```your_model_weights``` in the directory ```saved_models``` and set it as the ```hypar["model_path"] ="../saved_models/your_model_weights"``` and make sure ```hypar["valid_out_dir"]```(line 668) is set to ```""```, otherwise the prediction maps of the validation stage will be saved to that directory, which will slow the training speed down
|
||||||
|
#### (d) Run
|
||||||
|
```
|
||||||
|
python train_valid_inference_main.py
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
### (4) Inference
|
||||||
|
#### (a). Download the pre-trained weights (for fair academic comparisons only, the optimized model for engineering or common use will be released soon) ```isnet.pth``` from [(Google Drive)](https://drive.google.com/file/d/1KyMpRjewZdyYfxHPYcd-ZbanIXtin0Sn/view?usp=sharing) or [(Baidu Pan 提取码:xbfk)](https://pan.baidu.com/s/1-X2WutiBkWPt-oakuvZ10w?pwd=xbfk) and store ```isnet.pth``` in ```saved_models/IS-Net```
|
||||||
|
#### (b) Open ```train_valid_inference_main.py```, set the path of your to-be-inferenced ```valid_datasets```, e.g., ```valid_datasets=[dataset_te1, dataset_te2, dataset_te3, dataset_te4]```
|
||||||
|
#### (c) Set the ```hypar["mode"]``` to ```"valid"```
|
||||||
|
#### (d) Set the output directory of your predicted maps, e.g., ```hypar["valid_out_dir"] = "../DIS5K-Results-test"```
|
||||||
|
#### (e) Run
|
||||||
|
```
|
||||||
|
python train_valid_inference_main.py
|
||||||
|
```
|
||||||
|
|
||||||
|
### (5) Use of our Human Correction Efforts(HCE) metric, set the ground truth directory ```gt_root``` and the prediction directory ```pred_root```. To reduce the time costs for computing HCE, the skeletion of the DIS5K dataset can be pre-computed and stored in ```gt_ske_root```. If ```gt_ske_root=""```, the HCE code will compute the skeleton online which usually takes a lot for time for large size ground truth. Then, run ```python hce_metric_main.py```. Other metrics are evaluated based on the [SOCToolbox](https://github.com/mczhuge/SOCToolbox).
|
||||||
|
|
||||||
|
<br>
|
||||||
|
|
||||||
|
## 7. Term of Use
|
||||||
|
Our code and evaluation metric use Apache License 2.0. The Terms of use for our DIS5K dataset is provided as [DIS5K-Dataset-Terms-of-Use.pdf](DIS5K-Dataset-Terms-of-Use.pdf).
|
||||||
|
|
||||||
|
<br>
|
||||||
|
|
||||||
|
## Acknowledgements
|
||||||
|
|
||||||
|
<br>
|
||||||
|
|
||||||
|
We would like to thank Dr. [Ibrahim Almakky](https://scholar.google.co.uk/citations?user=T9MTcK0AAAAJ&hl=en) for his helps in implementing the dataloader cache machanism of loading large-size training samples and Jiayi Zhu for his efforts in re-organizing our code and dataset.
|
||||||
|
|
||||||
|
<br>
|
||||||
|
|
||||||
## Citation
|
## Citation
|
||||||
|
|
||||||
|
<br>
|
||||||
|
|
||||||
```
|
```
|
||||||
@InProceedings{qin2022,
|
@InProceedings{qin2022,
|
||||||
author={Xuebin Qin and Hang Dai and Xiaobin Hu and Deng-Ping Fan and Ling Shao and Luc Van Gool},
|
author={Xuebin Qin and Hang Dai and Xiaobin Hu and Deng-Ping Fan and Ling Shao and Luc Van Gool},
|
||||||
@ -31,8 +151,15 @@ This is the official repo for our new project DIS:
|
|||||||
year={2022}
|
year={2022}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
## Our Previous Works
|
|
||||||
|
<br>
|
||||||
|
|
||||||
|
## Our Previous Works: [U<sup>2</sup>-Net](https://github.com/xuebinqin/U-2-Net), [BASNet](https://github.com/xuebinqin/BASNet).
|
||||||
|
|
||||||
|
<br>
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
@InProceedings{Qin_2020_PR,
|
@InProceedings{Qin_2020_PR,
|
||||||
title = {U2-Net: Going Deeper with Nested U-Structure for Salient Object Detection},
|
title = {U2-Net: Going Deeper with Nested U-Structure for Salient Object Detection},
|
||||||
author = {Qin, Xuebin and Zhang, Zichen and Huang, Chenyang and Dehghan, Masood and Zaiane, Osmar and Jagersand, Martin},
|
author = {Qin, Xuebin and Zhang, Zichen and Huang, Chenyang and Dehghan, Masood and Zaiane, Osmar and Jagersand, Martin},
|
||||||
@ -42,13 +169,6 @@ This is the official repo for our new project DIS:
|
|||||||
year = {2020}
|
year = {2020}
|
||||||
}
|
}
|
||||||
|
|
||||||
@article{qin2021boundary,
|
|
||||||
title={Boundary-aware segmentation network for mobile and web applications},
|
|
||||||
author={Qin, Xuebin and Fan, Deng-Ping and Huang, Chenyang and Diagne, Cyril and Zhang, Zichen and Sant'Anna, Adri{\`a} Cabeza and Suarez, Albert and Jagersand, Martin and Shao, Ling},
|
|
||||||
journal={arXiv preprint arXiv:2101.04704},
|
|
||||||
year={2021}
|
|
||||||
}
|
|
||||||
|
|
||||||
@InProceedings{Qin_2019_CVPR,
|
@InProceedings{Qin_2019_CVPR,
|
||||||
author = {Qin, Xuebin and Zhang, Zichen and Huang, Chenyang and Gao, Chao and Dehghan, Masood and Jagersand, Martin},
|
author = {Qin, Xuebin and Zhang, Zichen and Huang, Chenyang and Gao, Chao and Dehghan, Masood and Jagersand, Martin},
|
||||||
title = {BASNet: Boundary-Aware Salient Object Detection},
|
title = {BASNet: Boundary-Aware Salient Object Detection},
|
||||||
@ -56,3 +176,10 @@ This is the official repo for our new project DIS:
|
|||||||
month = {June},
|
month = {June},
|
||||||
year = {2019}
|
year = {2019}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@article{qin2021boundary,
|
||||||
|
title={Boundary-aware segmentation network for mobile and web applications},
|
||||||
|
author={Qin, Xuebin and Fan, Deng-Ping and Huang, Chenyang and Diagne, Cyril and Zhang, Zichen and Sant'Anna, Adri{\`a} Cabeza and Suarez, Albert and Jagersand, Martin and Shao, Ling},
|
||||||
|
journal={arXiv preprint arXiv:2101.04704},
|
||||||
|
year={2021}
|
||||||
|
}
|
BIN
figures/.DS_Store
vendored
Normal file
BIN
figures/3d-modeling.png
Normal file
After Width: | Height: | Size: 1.1 MiB |
BIN
figures/DIS5k-dataset-v1-sailship.png
Normal file
After Width: | Height: | Size: 1.0 MiB |
BIN
figures/categories.jpeg
Normal file
After Width: | Height: | Size: 928 KiB |
BIN
figures/complexities-qual.jpeg
Normal file
After Width: | Height: | Size: 667 KiB |
BIN
figures/dis5k-v1-sailship.jpeg
Normal file
After Width: | Height: | Size: 297 KiB |
BIN
figures/hce-metric.png
Normal file
After Width: | Height: | Size: 137 KiB |
BIN
figures/is-net.png
Normal file
After Width: | Height: | Size: 912 KiB |
BIN
figures/quan-comp.png
Normal file
After Width: | Height: | Size: 887 KiB |
BIN
figures/statistics.jpeg
Normal file
After Width: | Height: | Size: 525 KiB |
BIN
figures/video-3d.gif
Normal file
After Width: | Height: | Size: 14 MiB |