diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000..3fcc1f4 Binary files /dev/null and b/.DS_Store differ diff --git a/DIS5K-Dataset-Terms-of-Use.pdf b/DIS5K-Dataset-Terms-of-Use.pdf new file mode 100644 index 0000000..47853c1 Binary files /dev/null and b/DIS5K-Dataset-Terms-of-Use.pdf differ diff --git a/IS-Net/__pycache__/basics.cpython-37.pyc b/IS-Net/__pycache__/basics.cpython-37.pyc new file mode 100644 index 0000000..f088d3d Binary files /dev/null and b/IS-Net/__pycache__/basics.cpython-37.pyc differ diff --git a/IS-Net/__pycache__/data_loader_cache.cpython-37.pyc b/IS-Net/__pycache__/data_loader_cache.cpython-37.pyc new file mode 100644 index 0000000..fb903e7 Binary files /dev/null and b/IS-Net/__pycache__/data_loader_cache.cpython-37.pyc differ diff --git a/IS-Net/basics.py b/IS-Net/basics.py new file mode 100644 index 0000000..36ab7cc --- /dev/null +++ b/IS-Net/basics.py @@ -0,0 +1,74 @@ +import os +# os.environ['CUDA_VISIBLE_DEVICES'] = '2' +from skimage import io, transform +import torch +import torchvision +from torch.autograd import Variable +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import Dataset, DataLoader +from torchvision import transforms, utils +import torch.optim as optim + +import matplotlib.pyplot as plt +import numpy as np +from PIL import Image +import glob + +def mae_torch(pred,gt): + + h,w = gt.shape[0:2] + sumError = torch.sum(torch.absolute(torch.sub(pred.float(), gt.float()))) + maeError = torch.divide(sumError,float(h)*float(w)*255.0+1e-4) + + return maeError + +def f1score_torch(pd,gt): + + # print(gt.shape) + gtNum = torch.sum((gt>128).float()*1) ## number of ground truth pixels + + pp = pd[gt>128] + nn = pd[gt<=128] + + pp_hist =torch.histc(pp,bins=255,min=0,max=255) + nn_hist = torch.histc(nn,bins=255,min=0,max=255) + + + pp_hist_flip = torch.flipud(pp_hist) + nn_hist_flip = torch.flipud(nn_hist) + + pp_hist_flip_cum = torch.cumsum(pp_hist_flip, dim=0) + nn_hist_flip_cum = torch.cumsum(nn_hist_flip, dim=0) + + precision = (pp_hist_flip_cum)/(pp_hist_flip_cum + nn_hist_flip_cum + 1e-4)#torch.divide(pp_hist_flip_cum,torch.sum(torch.sum(pp_hist_flip_cum, nn_hist_flip_cum), 1e-4)) + recall = (pp_hist_flip_cum)/(gtNum + 1e-4) + f1 = (1+0.3)*precision*recall/(0.3*precision+recall + 1e-4) + + return torch.reshape(precision,(1,precision.shape[0])),torch.reshape(recall,(1,recall.shape[0])),torch.reshape(f1,(1,f1.shape[0])) + + +def f1_mae_torch(pred, gt, valid_dataset, idx, mybins, hypar): + + import time + tic = time.time() + + if(len(gt.shape)>2): + gt = gt[:,:,0] + + pre, rec, f1 = f1score_torch(pred,gt) + mae = mae_torch(pred,gt) + + + # hypar["valid_out_dir"] = hypar["valid_out_dir"]+"-eval" ### + if(hypar["valid_out_dir"]!=""): + if(not os.path.exists(hypar["valid_out_dir"])): + os.mkdir(hypar["valid_out_dir"]) + dataset_folder = os.path.join(hypar["valid_out_dir"],valid_dataset.dataset["data_name"][idx]) + if(not os.path.exists(dataset_folder)): + os.mkdir(dataset_folder) + io.imsave(os.path.join(dataset_folder,valid_dataset.dataset["im_name"][idx]+".png"),pred.cpu().data.numpy().astype(np.uint8)) + print(valid_dataset.dataset["im_name"][idx]+".png") + print("time for evaluation : ", time.time()-tic) + + return pre.cpu().data.numpy(), rec.cpu().data.numpy(), f1.cpu().data.numpy(), mae.cpu().data.numpy() diff --git a/IS-Net/data_loader_cache.py b/IS-Net/data_loader_cache.py new file mode 100644 index 0000000..c99e6a8 --- /dev/null +++ b/IS-Net/data_loader_cache.py @@ -0,0 +1,380 @@ +## data loader +## Ackownledgement: +## We would like to thank Dr. Ibrahim Almakky (https://scholar.google.co.uk/citations?user=T9MTcK0AAAAJ&hl=en) +## for his helps in implementing cache machanism of our DIS dataloader. +from __future__ import print_function, division + +import numpy as np +import random +from copy import deepcopy +import json +from tqdm import tqdm +from skimage import io +import os +from glob import glob + +import torch +from torch.utils.data import Dataset, DataLoader +from torchvision import transforms, utils +from torchvision.transforms.functional import normalize +import torch.nn.functional as F + +#### --------------------- DIS dataloader cache ---------------------#### + +def get_im_gt_name_dict(datasets, flag='valid'): + print("------------------------------", flag, "--------------------------------") + name_im_gt_list = [] + for i in range(len(datasets)): + print("--->>>", flag, " dataset ",i,"/",len(datasets)," ",datasets[i]["name"],"<<<---") + tmp_im_list, tmp_gt_list = [], [] + tmp_im_list = glob(datasets[i]["im_dir"]+os.sep+'*'+datasets[i]["im_ext"]) + + # img_name_dict[im_dirs[i][0]] = tmp_im_list + print('-im-',datasets[i]["name"],datasets[i]["im_dir"], ': ',len(tmp_im_list)) + + if(datasets[i]["gt_dir"]==""): + print('-gt-', datasets[i]["name"], datasets[i]["gt_dir"], ': ', 'No Ground Truth Found') + else: + tmp_gt_list = [datasets[i]["gt_dir"]+os.sep+x.split(os.sep)[-1].split(datasets[i]["im_ext"])[0]+datasets[i]["gt_ext"] for x in tmp_im_list] + # lbl_name_dict[im_dirs[i][0]] = tmp_gt_list + print('-gt-', datasets[i]["name"],datasets[i]["gt_dir"], ': ',len(tmp_gt_list)) + + + if flag=="train": ## combine multiple training sets into one dataset + if len(name_im_gt_list)==0: + name_im_gt_list.append({"dataset_name":datasets[i]["name"], + "im_path":tmp_im_list, + "gt_path":tmp_gt_list, + "im_ext":datasets[i]["im_ext"], + "gt_ext":datasets[i]["gt_ext"], + "cache_dir":datasets[i]["cache_dir"]}) + else: + name_im_gt_list[0]["dataset_name"] = name_im_gt_list[0]["dataset_name"] + "_" + datasets[i]["name"] + name_im_gt_list[0]["im_path"] = name_im_gt_list[0]["im_path"] + tmp_im_list + name_im_gt_list[0]["gt_path"] = name_im_gt_list[0]["gt_path"] + tmp_gt_list + if datasets[i]["im_ext"]!=".jpg" or datasets[i]["gt_ext"]!=".png": + print("Error: Please make sure all you images and ground truth masks are in jpg and png format respectively !!!") + exit() + name_im_gt_list[0]["im_ext"] = ".jpg" + name_im_gt_list[0]["gt_ext"] = ".png" + name_im_gt_list[0]["cache_dir"] = os.sep.join(datasets[i]["cache_dir"].split(os.sep)[0:-1])+os.sep+name_im_gt_list[0]["dataset_name"] + else: ## keep different validation or inference datasets as separate ones + name_im_gt_list.append({"dataset_name":datasets[i]["name"], + "im_path":tmp_im_list, + "gt_path":tmp_gt_list, + "im_ext":datasets[i]["im_ext"], + "gt_ext":datasets[i]["gt_ext"], + "cache_dir":datasets[i]["cache_dir"]}) + + return name_im_gt_list + +def create_dataloaders(name_im_gt_list, cache_size=[], cache_boost=True, my_transforms=[], batch_size=1, shuffle=False): + ## model="train": return one dataloader for training + ## model="valid": return a list of dataloaders for validation or testing + + gos_dataloaders = [] + gos_datasets = [] + # if(mode=="train"): + if(len(name_im_gt_list)==0): + return gos_dataloaders, gos_datasets + + num_workers_ = 1 + if(batch_size>1): + num_workers_ = 2 + if(batch_size>4): + num_workers_ = 4 + if(batch_size>8): + num_workers_ = 8 + + for i in range(0,len(name_im_gt_list)): + gos_dataset = GOSDatasetCache([name_im_gt_list[i]], + cache_size = cache_size, + cache_path = name_im_gt_list[i]["cache_dir"], + cache_boost = cache_boost, + transform = transforms.Compose(my_transforms)) + gos_dataloaders.append(DataLoader(gos_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers_)) + gos_datasets.append(gos_dataset) + + return gos_dataloaders, gos_datasets + +def im_reader(im_path): + return io.imread(im_path) + +def im_preprocess(im,size): + if len(im.shape) < 3: + im = im[:, :, np.newaxis] + if im.shape[2] == 1: + im = np.repeat(im, 3, axis=2) + im_tensor = torch.tensor(im, dtype=torch.float32) + im_tensor = torch.transpose(torch.transpose(im_tensor,1,2),0,1) + if(len(size)<2): + return im_tensor, im.shape[0:2] + else: + im_tensor = torch.unsqueeze(im_tensor,0) + im_tensor = F.upsample(im_tensor, size, mode="bilinear") + im_tensor = torch.squeeze(im_tensor,0) + + return im_tensor.type(torch.uint8), im.shape[0:2] + +def gt_preprocess(gt,size): + if len(gt.shape) > 2: + gt = gt[:, :, 0] + + gt_tensor = torch.unsqueeze(torch.tensor(gt, dtype=torch.uint8),0) + + if(len(size)<2): + return gt_tensor.type(torch.uint8), gt.shape[0:2] + else: + gt_tensor = torch.unsqueeze(torch.tensor(gt_tensor, dtype=torch.float32),0) + gt_tensor = F.upsample(gt_tensor, size, mode="bilinear") + gt_tensor = torch.squeeze(gt_tensor,0) + + return gt_tensor.type(torch.uint8), gt.shape[0:2] + # return gt_tensor, gt.shape[0:2] + +class GOSRandomHFlip(object): + def __init__(self,prob=0.5): + self.prob = prob + def __call__(self,sample): + imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape'] + + # random horizontal flip + if random.random() >= self.prob: + image = torch.flip(image,dims=[2]) + label = torch.flip(label,dims=[2]) + + return {'imidx':imidx,'image':image, 'label':label, 'shape':shape} + +class GOSResize(object): + def __init__(self,size=[320,320]): + self.size = size + def __call__(self,sample): + imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape'] + + # import time + # start = time.time() + + image = torch.squeeze(F.upsample(torch.unsqueeze(image,0),self.size,mode='bilinear'),dim=0) + label = torch.squeeze(F.upsample(torch.unsqueeze(label,0),self.size,mode='bilinear'),dim=0) + + # print("time for resize: ", time.time()-start) + + return {'imidx':imidx,'image':image, 'label':label, 'shape':shape} + +class GOSRandomCrop(object): + def __init__(self,size=[288,288]): + self.size = size + def __call__(self,sample): + imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape'] + + h, w = image.shape[1:] + new_h, new_w = self.size + + top = np.random.randint(0, h - new_h) + left = np.random.randint(0, w - new_w) + + image = image[:,top:top+new_h,left:left+new_w] + label = label[:,top:top+new_h,left:left+new_w] + + return {'imidx':imidx,'image':image, 'label':label, 'shape':shape} + + +class GOSNormalize(object): + def __init__(self, mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]): + self.mean = mean + self.std = std + + def __call__(self,sample): + + imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape'] + image = normalize(image,self.mean,self.std) + + return {'imidx':imidx,'image':image, 'label':label, 'shape':shape} + + +class GOSDatasetCache(Dataset): + + def __init__(self, name_im_gt_list, cache_size=[], cache_path='./cache', cache_file_name='dataset.json', cache_boost=False, transform=None): + + + self.cache_size = cache_size + self.cache_path = cache_path + self.cache_file_name = cache_file_name + self.cache_boost_name = "" + + self.cache_boost = cache_boost + # self.ims_npy = None + # self.gts_npy = None + + ## cache all the images and ground truth into a single pytorch tensor + self.ims_pt = None + self.gts_pt = None + + ## we will cache the npy as well regardless of the cache_boost + # if(self.cache_boost): + self.cache_boost_name = cache_file_name.split('.json')[0] + + self.transform = transform + + self.dataset = {} + + ## combine different datasets into one + dataset_names = [] + dt_name_list = [] # dataset name per image + im_name_list = [] # image name + im_path_list = [] # im path + gt_path_list = [] # gt path + im_ext_list = [] # im ext + gt_ext_list = [] # gt ext + for i in range(0,len(name_im_gt_list)): + dataset_names.append(name_im_gt_list[i]["dataset_name"]) + # dataset name repeated based on the number of images in this dataset + dt_name_list.extend([name_im_gt_list[i]["dataset_name"] for x in name_im_gt_list[i]["im_path"]]) + im_name_list.extend([x.split(os.sep)[-1].split(name_im_gt_list[i]["im_ext"])[0] for x in name_im_gt_list[i]["im_path"]]) + im_path_list.extend(name_im_gt_list[i]["im_path"]) + gt_path_list.extend(name_im_gt_list[i]["gt_path"]) + im_ext_list.extend([name_im_gt_list[i]["im_ext"] for x in name_im_gt_list[i]["im_path"]]) + gt_ext_list.extend([name_im_gt_list[i]["gt_ext"] for x in name_im_gt_list[i]["gt_path"]]) + + + self.dataset["data_name"] = dt_name_list + self.dataset["im_name"] = im_name_list + self.dataset["im_path"] = im_path_list + self.dataset["ori_im_path"] = deepcopy(im_path_list) + self.dataset["gt_path"] = gt_path_list + self.dataset["ori_gt_path"] = deepcopy(gt_path_list) + self.dataset["im_shp"] = [] + self.dataset["gt_shp"] = [] + self.dataset["im_ext"] = im_ext_list + self.dataset["gt_ext"] = gt_ext_list + + + self.dataset["ims_pt_dir"] = "" + self.dataset["gts_pt_dir"] = "" + + self.dataset = self.manage_cache(dataset_names) + + def manage_cache(self,dataset_names): + if not os.path.exists(self.cache_path): # create the folder for cache + os.mkdir(self.cache_path) + cache_folder = os.path.join(self.cache_path, "_".join(dataset_names)+"_"+"x".join([str(x) for x in self.cache_size])) + if not os.path.exists(cache_folder): # check if the cache files are there, if not then cache + return self.cache(cache_folder) + return self.load_cache(cache_folder) + + def cache(self,cache_folder): + os.mkdir(cache_folder) + cached_dataset = deepcopy(self.dataset) + + # ims_list = [] + # gts_list = [] + ims_pt_list = [] + gts_pt_list = [] + for i, im_path in tqdm(enumerate(self.dataset["im_path"]), total=len(self.dataset["im_path"])): + + + im_id = cached_dataset["im_name"][i] + + im = im_reader(im_path) + im, im_shp = im_preprocess(im,self.cache_size) + im_cache_file = os.path.join(cache_folder,self.dataset["data_name"][i]+"_"+im_id + "_im.pt") + torch.save(im,im_cache_file) + + cached_dataset["im_path"][i] = im_cache_file + if(self.cache_boost): + ims_pt_list.append(torch.unsqueeze(im,0)) + # ims_list.append(im.cpu().data.numpy().astype(np.uint8)) + + + gt = im_reader(self.dataset["gt_path"][i]) + gt, gt_shp = gt_preprocess(gt,self.cache_size) + gt_cache_file = os.path.join(cache_folder,self.dataset["data_name"][i]+"_"+im_id + "_gt.pt") + torch.save(gt,gt_cache_file) + cached_dataset["gt_path"][i] = gt_cache_file + if(self.cache_boost): + gts_pt_list.append(torch.unsqueeze(gt,0)) + # gts_list.append(gt.cpu().data.numpy().astype(np.uint8)) + + # im_shp_cache_file = os.path.join(cache_folder,im_id + "_im_shp.pt") + # torch.save(gt_shp, shp_cache_file) + cached_dataset["im_shp"].append(im_shp) + # self.dataset["im_shp"].append(im_shp) + + # shp_cache_file = os.path.join(cache_folder,im_id + "_gt_shp.pt") + # torch.save(gt_shp, shp_cache_file) + cached_dataset["gt_shp"].append(gt_shp) + # self.dataset["gt_shp"].append(gt_shp) + + if(self.cache_boost): + cached_dataset["ims_pt_dir"] = os.path.join(cache_folder, self.cache_boost_name+'_ims.pt') + cached_dataset["gts_pt_dir"] = os.path.join(cache_folder, self.cache_boost_name+'_gts.pt') + self.ims_pt = torch.cat(ims_pt_list,dim=0) + self.gts_pt = torch.cat(gts_pt_list,dim=0) + torch.save(torch.cat(ims_pt_list,dim=0),cached_dataset["ims_pt_dir"]) + torch.save(torch.cat(gts_pt_list,dim=0),cached_dataset["gts_pt_dir"]) + + try: + json_file = open(os.path.join(cache_folder, self.cache_file_name),"w") + json.dump(cached_dataset, json_file) + json_file.close() + except Exception: + raise FileNotFoundError("Cannot create JSON") + return cached_dataset + + def load_cache(self, cache_folder): + json_file = open(os.path.join(cache_folder,self.cache_file_name),"r") + dataset = json.load(json_file) + json_file.close() + ## if cache_boost is true, we will load the image npy and ground truth npy into the RAM + ## otherwise the pytorch tensor will be loaded + if(self.cache_boost): + # self.ims_npy = np.load(dataset["ims_npy_dir"]) + # self.gts_npy = np.load(dataset["gts_npy_dir"]) + self.ims_pt = torch.load(dataset["ims_pt_dir"], map_location='cpu') + self.gts_pt = torch.load(dataset["gts_pt_dir"], map_location='cpu') + return dataset + + def __len__(self): + return len(self.dataset["im_path"]) + + def __getitem__(self, idx): + + im = None + gt = None + if(self.cache_boost and self.ims_pt is not None): + + # start = time.time() + im = self.ims_pt[idx]#.type(torch.float32) + gt = self.gts_pt[idx]#.type(torch.float32) + # print(idx, 'time for pt loading: ', time.time()-start) + + else: + # import time + # start = time.time() + # print("tensor***") + im_pt_path = os.path.join(self.cache_path,os.sep.join(self.dataset["im_path"][idx].split(os.sep)[-2:])) + im = torch.load(im_pt_path)#(self.dataset["im_path"][idx]) + gt_pt_path = os.path.join(self.cache_path,os.sep.join(self.dataset["gt_path"][idx].split(os.sep)[-2:])) + gt = torch.load(gt_pt_path)#(self.dataset["gt_path"][idx]) + # print(idx,'time for tensor loading: ', time.time()-start) + + + im_shp = self.dataset["im_shp"][idx] + # print("time for loading im and gt: ", time.time()-start) + + # start_time = time.time() + im = torch.divide(im,255.0) + gt = torch.divide(gt,255.0) + # print(idx, 'time for normalize torch divide: ', time.time()-start_time) + + sample = { + "imidx": torch.from_numpy(np.array(idx)), + "image": im, + "label": gt, + "shape": torch.from_numpy(np.array(im_shp)), + } + + if self.transform: + sample = self.transform(sample) + + return sample diff --git a/IS-Net/hce_metric_main.py b/IS-Net/hce_metric_main.py new file mode 100644 index 0000000..1b102e6 --- /dev/null +++ b/IS-Net/hce_metric_main.py @@ -0,0 +1,188 @@ +## hce_metric.py +import numpy as np +from skimage import io +import matplotlib.pyplot as plt +import cv2 as cv +from skimage.morphology import skeletonize +from skimage.morphology import erosion, dilation, disk +from skimage.measure import label + +import os +import sys +from tqdm import tqdm +from glob import glob +import pickle as pkl + +def filter_bdy_cond(bdy_, mask, cond): + + cond = cv.dilate(cond.astype(np.uint8),disk(1)) + labels = label(mask) # find the connected regions + lbls = np.unique(labels) # the indices of the connected regions + indep = np.ones(lbls.shape[0]) # the label of each connected regions + indep[0] = 0 # 0 indicate the background region + + boundaries = [] + h,w = cond.shape[0:2] + ind_map = np.zeros((h,w)) + indep_cnt = 0 + + for i in range(0,len(bdy_)): + tmp_bdies = [] + tmp_bdy = [] + for j in range(0,bdy_[i].shape[0]): + r, c = bdy_[i][j,0,1],bdy_[i][j,0,0] + + if(np.sum(cond[r,c])==0 or ind_map[r,c]!=0): + if(len(tmp_bdy)>0): + tmp_bdies.append(tmp_bdy) + tmp_bdy = [] + continue + tmp_bdy.append([c,r]) + ind_map[r,c] = ind_map[r,c] + 1 + indep[labels[r,c]] = 0 # indicates part of the boundary of this region needs human correction + if(len(tmp_bdy)>0): + tmp_bdies.append(tmp_bdy) + + # check if the first and the last boundaries are connected + # if yes, invert the first boundary and attach it after the last boundary + if(len(tmp_bdies)>1): + first_x, first_y = tmp_bdies[0][0] + last_x, last_y = tmp_bdies[-1][-1] + if((abs(first_x-last_x)==1 and first_y==last_y) or + (first_x==last_x and abs(first_y-last_y)==1) or + (abs(first_x-last_x)==1 and abs(first_y-last_y)==1) + ): + tmp_bdies[-1].extend(tmp_bdies[0][::-1]) + del tmp_bdies[0] + + for k in range(0,len(tmp_bdies)): + tmp_bdies[k] = np.array(tmp_bdies[k])[:,np.newaxis,:] + if(len(tmp_bdies)>0): + boundaries.extend(tmp_bdies) + + return boundaries, np.sum(indep) + +# this function approximate each boundary by DP algorithm +# https://en.wikipedia.org/wiki/Ramer%E2%80%93Douglas%E2%80%93Peucker_algorithm +def approximate_RDP(boundaries,epsilon=1.0): + + boundaries_ = [] + boundaries_len_ = [] + pixel_cnt_ = 0 + + # polygon approximate of each boundary + for i in range(0,len(boundaries)): + boundaries_.append(cv.approxPolyDP(boundaries[i],epsilon,False)) + + # count the control points number of each boundary and the total control points number of all the boundaries + for i in range(0,len(boundaries_)): + boundaries_len_.append(len(boundaries_[i])) + pixel_cnt_ = pixel_cnt_ + len(boundaries_[i]) + + return boundaries_, boundaries_len_, pixel_cnt_ + + +def relax_HCE(gt, rs, gt_ske, relax=5, epsilon=2.0): + # print("max(gt_ske): ", np.amax(gt_ske)) + # gt_ske = gt_ske>128 + # print("max(gt_ske): ", np.amax(gt_ske)) + + # Binarize gt + if(len(gt.shape)>2): + gt = gt[:,:,0] + + epsilon_gt = 128#(np.amin(gt)+np.amax(gt))/2.0 + gt = (gt>epsilon_gt).astype(np.uint8) + + # Binarize rs + if(len(rs.shape)>2): + rs = rs[:,:,0] + epsilon_rs = 128#(np.amin(rs)+np.amax(rs))/2.0 + rs = (rs>epsilon_rs).astype(np.uint8) + + Union = np.logical_or(gt,rs) + TP = np.logical_and(gt,rs) + FP = rs - TP + FN = gt - TP + + # relax the Union of gt and rs + Union_erode = Union.copy() + Union_erode = cv.erode(Union_erode.astype(np.uint8),disk(1),iterations=relax) + + # --- get the relaxed False Positive regions for computing the human efforts in correcting them --- + FP_ = np.logical_and(FP,Union_erode) # get the relaxed FP + for i in range(0,relax): + FP_ = cv.dilate(FP_.astype(np.uint8),disk(1)) + FP_ = np.logical_and(FP_, 1-np.logical_or(TP,FN)) + FP_ = np.logical_and(FP, FP_) + + # --- get the relaxed False Negative regions for computing the human efforts in correcting them --- + FN_ = np.logical_and(FN,Union_erode) # preserve the structural components of FN + ## recover the FN, where pixels are not close to the TP borders + for i in range(0,relax): + FN_ = cv.dilate(FN_.astype(np.uint8),disk(1)) + FN_ = np.logical_and(FN_,1-np.logical_or(TP,FP)) + FN_ = np.logical_and(FN,FN_) + FN_ = np.logical_or(FN_, np.logical_xor(gt_ske,np.logical_and(TP,gt_ske))) # preserve the structural components of FN + + ## 2. =============Find exact polygon control points and independent regions============== + ## find contours from FP_ + ctrs_FP, hier_FP = cv.findContours(FP_.astype(np.uint8), cv.RETR_TREE, cv.CHAIN_APPROX_NONE) + ## find control points and independent regions for human correction + bdies_FP, indep_cnt_FP = filter_bdy_cond(ctrs_FP, FP_, np.logical_or(TP,FN_)) + ## find contours from FN_ + ctrs_FN, hier_FN = cv.findContours(FN_.astype(np.uint8), cv.RETR_TREE, cv.CHAIN_APPROX_NONE) + ## find control points and independent regions for human correction + bdies_FN, indep_cnt_FN = filter_bdy_cond(ctrs_FN, FN_, 1-np.logical_or(np.logical_or(TP,FP_),FN_)) + + poly_FP, poly_FP_len, poly_FP_point_cnt = approximate_RDP(bdies_FP,epsilon=epsilon) + poly_FN, poly_FN_len, poly_FN_point_cnt = approximate_RDP(bdies_FN,epsilon=epsilon) + + return poly_FP_point_cnt, indep_cnt_FP, poly_FN_point_cnt, indep_cnt_FN + +def compute_hce(pred_root,gt_root,gt_ske_root): + + gt_name_list = glob(pred_root+'/*.png') + gt_name_list = sorted([x.split('/')[-1] for x in gt_name_list]) + + hces = [] + for gt_name in tqdm(gt_name_list, total=len(gt_name_list)): + gt_path = os.path.join(gt_root, gt_name) + pred_path = os.path.join(pred_root, gt_name) + + gt = cv.imread(gt_path, cv.IMREAD_GRAYSCALE) + pred = cv.imread(pred_path, cv.IMREAD_GRAYSCALE) + + ske_path = os.path.join(gt_ske_root,gt_name) + if os.path.exists(ske_path): + ske = cv.imread(ske_path,cv.IMREAD_GRAYSCALE) + ske = ske>128 + else: + ske = skeletonize(gt>128) + + FP_points, FP_indep, FN_points, FN_indep = relax_HCE(gt, pred,ske) + print(gt_path.split('/')[-1],FP_points, FP_indep, FN_points, FN_indep) + hces.append([FP_points, FP_indep, FN_points, FN_indep, FP_points+FP_indep+FN_points+FN_indep]) + + hce_metric ={'names': gt_name_list, + 'hces': hces} + + + file_metric = open(pred_root+'/hce_metric.pkl','wb') + pkl.dump(hce_metric,file_metric) + # file_metrics.write(cmn_metrics) + file_metric.close() + + return np.mean(np.array(hces)[:,-1]) + +def main(): + + gt_root = "../DIS5K/DIS-VD/gt" + gt_ske_root = "" + pred_root = "../Results/isnet(ours)/DIS-VD" + + print("The average HCE metric: ", compute_hce(pred_root,gt_root,gt_ske_root)) + + +if __name__ == '__main__': + main() diff --git a/IS-Net/models/__init__.py b/IS-Net/models/__init__.py new file mode 100644 index 0000000..74aba2b --- /dev/null +++ b/IS-Net/models/__init__.py @@ -0,0 +1 @@ +from models.isnet import ISNetGTEncoder, ISNetDIS diff --git a/IS-Net/models/__pycache__/__init__.cpython-37.pyc b/IS-Net/models/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000..23387c5 Binary files /dev/null and b/IS-Net/models/__pycache__/__init__.cpython-37.pyc differ diff --git a/IS-Net/models/__pycache__/u2netfast.cpython-37.pyc b/IS-Net/models/__pycache__/u2netfast.cpython-37.pyc new file mode 100644 index 0000000..a13adc7 Binary files /dev/null and b/IS-Net/models/__pycache__/u2netfast.cpython-37.pyc differ diff --git a/IS-Net/models/isnet.py b/IS-Net/models/isnet.py new file mode 100644 index 0000000..a3937b6 --- /dev/null +++ b/IS-Net/models/isnet.py @@ -0,0 +1,614 @@ +import torch +import torch.nn as nn +from torchvision import models +import torch.nn.functional as F + + +bce_loss = nn.BCELoss(size_average=True) +def muti_loss_fusion(preds, target): + loss0 = 0.0 + loss = 0.0 + + for i in range(0,len(preds)): + # print("i: ", i, preds[i].shape) + if(preds[i].shape[2]!=target.shape[2] or preds[i].shape[3]!=target.shape[3]): + # tmp_target = _upsample_like(target,preds[i]) + tmp_target = F.interpolate(target, size=preds[i].size()[2:], mode='bilinear', align_corners=True) + loss = loss + bce_loss(preds[i],tmp_target) + else: + loss = loss + bce_loss(preds[i],target) + if(i==0): + loss0 = loss + return loss0, loss + +fea_loss = nn.MSELoss(size_average=True) +kl_loss = nn.KLDivLoss(size_average=True) +l1_loss = nn.L1Loss(size_average=True) +smooth_l1_loss = nn.SmoothL1Loss(size_average=True) +def muti_loss_fusion_kl(preds, target, dfs, fs, mode='MSE'): + loss0 = 0.0 + loss = 0.0 + + for i in range(0,len(preds)): + # print("i: ", i, preds[i].shape) + if(preds[i].shape[2]!=target.shape[2] or preds[i].shape[3]!=target.shape[3]): + # tmp_target = _upsample_like(target,preds[i]) + tmp_target = F.interpolate(target, size=preds[i].size()[2:], mode='bilinear', align_corners=True) + loss = loss + bce_loss(preds[i],tmp_target) + else: + loss = loss + bce_loss(preds[i],target) + if(i==0): + loss0 = loss + + for i in range(0,len(dfs)): + if(mode=='MSE'): + loss = loss + fea_loss(dfs[i],fs[i]) ### add the mse loss of features as additional constraints + # print("fea_loss: ", fea_loss(dfs[i],fs[i]).item()) + elif(mode=='KL'): + loss = loss + kl_loss(F.log_softmax(dfs[i],dim=1),F.softmax(fs[i],dim=1)) + # print("kl_loss: ", kl_loss(F.log_softmax(dfs[i],dim=1),F.softmax(fs[i],dim=1)).item()) + elif(mode=='MAE'): + loss = loss + l1_loss(dfs[i],fs[i]) + # print("ls_loss: ", l1_loss(dfs[i],fs[i])) + elif(mode=='SmoothL1'): + loss = loss + smooth_l1_loss(dfs[i],fs[i]) + # print("SmoothL1: ", smooth_l1_loss(dfs[i],fs[i]).item()) + + return loss0, loss + +class REBNCONV(nn.Module): + def __init__(self,in_ch=3,out_ch=3,dirate=1,stride=1): + super(REBNCONV,self).__init__() + + self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate,stride=stride) + self.bn_s1 = nn.BatchNorm2d(out_ch) + self.relu_s1 = nn.ReLU(inplace=True) + + def forward(self,x): + + hx = x + xout = self.relu_s1(self.bn_s1(self.conv_s1(hx))) + + return xout + +## upsample tensor 'src' to have the same spatial size with tensor 'tar' +def _upsample_like(src,tar): + + src = F.upsample(src,size=tar.shape[2:],mode='bilinear') + + return src + + +### RSU-7 ### +class RSU7(nn.Module): + + def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512): + super(RSU7,self).__init__() + + self.in_ch = in_ch + self.mid_ch = mid_ch + self.out_ch = out_ch + + self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) ## 1 -> 1/2 + + self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1) + self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1) + + self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2) + + self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1) + + def forward(self,x): + b, c, h, w = x.shape + + hx = x + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + hx = self.pool3(hx3) + + hx4 = self.rebnconv4(hx) + hx = self.pool4(hx4) + + hx5 = self.rebnconv5(hx) + hx = self.pool5(hx5) + + hx6 = self.rebnconv6(hx) + + hx7 = self.rebnconv7(hx6) + + hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1)) + hx6dup = _upsample_like(hx6d,hx5) + + hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1)) + hx5dup = _upsample_like(hx5d,hx4) + + hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1)) + hx4dup = _upsample_like(hx4d,hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1)) + hx3dup = _upsample_like(hx3d,hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1)) + hx2dup = _upsample_like(hx2d,hx1) + + hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1)) + + return hx1d + hxin + + +### RSU-6 ### +class RSU6(nn.Module): + + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU6,self).__init__() + + self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) + + self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1) + self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1) + + self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2) + + self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1) + + def forward(self,x): + + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + hx = self.pool3(hx3) + + hx4 = self.rebnconv4(hx) + hx = self.pool4(hx4) + + hx5 = self.rebnconv5(hx) + + hx6 = self.rebnconv6(hx5) + + + hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1)) + hx5dup = _upsample_like(hx5d,hx4) + + hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1)) + hx4dup = _upsample_like(hx4d,hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1)) + hx3dup = _upsample_like(hx3d,hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1)) + hx2dup = _upsample_like(hx2d,hx1) + + hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1)) + + return hx1d + hxin + +### RSU-5 ### +class RSU5(nn.Module): + + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU5,self).__init__() + + self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) + + self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1) + self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1) + + self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2) + + self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1) + + def forward(self,x): + + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + hx = self.pool3(hx3) + + hx4 = self.rebnconv4(hx) + + hx5 = self.rebnconv5(hx4) + + hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1)) + hx4dup = _upsample_like(hx4d,hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1)) + hx3dup = _upsample_like(hx3d,hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1)) + hx2dup = _upsample_like(hx2d,hx1) + + hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1)) + + return hx1d + hxin + +### RSU-4 ### +class RSU4(nn.Module): + + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU4,self).__init__() + + self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) + + self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1) + self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1) + + self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2) + + self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1) + + def forward(self,x): + + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + + hx4 = self.rebnconv4(hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1)) + hx3dup = _upsample_like(hx3d,hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1)) + hx2dup = _upsample_like(hx2d,hx1) + + hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1)) + + return hx1d + hxin + +### RSU-4F ### +class RSU4F(nn.Module): + + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU4F,self).__init__() + + self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) + + self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1) + self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2) + self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4) + + self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8) + + self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4) + self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2) + self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1) + + def forward(self,x): + + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx2 = self.rebnconv2(hx1) + hx3 = self.rebnconv3(hx2) + + hx4 = self.rebnconv4(hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1)) + hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1)) + hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1)) + + return hx1d + hxin + + +class myrebnconv(nn.Module): + def __init__(self, in_ch=3, + out_ch=1, + kernel_size=3, + stride=1, + padding=1, + dilation=1, + groups=1): + super(myrebnconv,self).__init__() + + self.conv = nn.Conv2d(in_ch, + out_ch, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups) + self.bn = nn.BatchNorm2d(out_ch) + self.rl = nn.ReLU(inplace=True) + + def forward(self,x): + return self.rl(self.bn(self.conv(x))) + + +class ISNetGTEncoder(nn.Module): + + def __init__(self,in_ch=1,out_ch=1): + super(ISNetGTEncoder,self).__init__() + + self.conv_in = myrebnconv(in_ch,16,3,stride=2,padding=1) # nn.Conv2d(in_ch,64,3,stride=2,padding=1) + + self.stage1 = RSU7(16,16,64) + self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.stage2 = RSU6(64,16,64) + self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.stage3 = RSU5(64,32,128) + self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.stage4 = RSU4(128,32,256) + self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.stage5 = RSU4F(256,64,512) + self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.stage6 = RSU4F(512,64,512) + + + self.side1 = nn.Conv2d(64,out_ch,3,padding=1) + self.side2 = nn.Conv2d(64,out_ch,3,padding=1) + self.side3 = nn.Conv2d(128,out_ch,3,padding=1) + self.side4 = nn.Conv2d(256,out_ch,3,padding=1) + self.side5 = nn.Conv2d(512,out_ch,3,padding=1) + self.side6 = nn.Conv2d(512,out_ch,3,padding=1) + + def compute_loss_max(self, preds, targets, fs): + + return muti_loss_fusion_max(preds, targets,fs) + + def compute_loss(self, preds, targets): + + return muti_loss_fusion(preds,targets) + + def forward(self,x): + + hx = x + + hxin = self.conv_in(hx) + # hx = self.pool_in(hxin) + + #stage 1 + hx1 = self.stage1(hxin) + hx = self.pool12(hx1) + + #stage 2 + hx2 = self.stage2(hx) + hx = self.pool23(hx2) + + #stage 3 + hx3 = self.stage3(hx) + hx = self.pool34(hx3) + + #stage 4 + hx4 = self.stage4(hx) + hx = self.pool45(hx4) + + #stage 5 + hx5 = self.stage5(hx) + hx = self.pool56(hx5) + + #stage 6 + hx6 = self.stage6(hx) + + + #side output + d1 = self.side1(hx1) + d1 = _upsample_like(d1,x) + + d2 = self.side2(hx2) + d2 = _upsample_like(d2,x) + + d3 = self.side3(hx3) + d3 = _upsample_like(d3,x) + + d4 = self.side4(hx4) + d4 = _upsample_like(d4,x) + + d5 = self.side5(hx5) + d5 = _upsample_like(d5,x) + + d6 = self.side6(hx6) + d6 = _upsample_like(d6,x) + + # d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1)) + + return [F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)], [hx1,hx2,hx3,hx4,hx5,hx6] + +class ISNetDIS(nn.Module): + + def __init__(self,in_ch=3,out_ch=1): + super(ISNetDIS,self).__init__() + + self.conv_in = nn.Conv2d(in_ch,64,3,stride=2,padding=1) + self.pool_in = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.stage1 = RSU7(64,32,64) + self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.stage2 = RSU6(64,32,128) + self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.stage3 = RSU5(128,64,256) + self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.stage4 = RSU4(256,128,512) + self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.stage5 = RSU4F(512,256,512) + self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.stage6 = RSU4F(512,256,512) + + # decoder + self.stage5d = RSU4F(1024,256,512) + self.stage4d = RSU4(1024,128,256) + self.stage3d = RSU5(512,64,128) + self.stage2d = RSU6(256,32,64) + self.stage1d = RSU7(128,16,64) + + self.side1 = nn.Conv2d(64,out_ch,3,padding=1) + self.side2 = nn.Conv2d(64,out_ch,3,padding=1) + self.side3 = nn.Conv2d(128,out_ch,3,padding=1) + self.side4 = nn.Conv2d(256,out_ch,3,padding=1) + self.side5 = nn.Conv2d(512,out_ch,3,padding=1) + self.side6 = nn.Conv2d(512,out_ch,3,padding=1) + + # self.outconv = nn.Conv2d(6*out_ch,out_ch,1) + + def compute_loss_kl(self, preds, targets, dfs, fs, mode='MSE'): + + # return muti_loss_fusion(preds,targets) + return muti_loss_fusion_kl(preds, targets, dfs, fs, mode=mode) + + def compute_loss(self, preds, targets): + + # return muti_loss_fusion(preds,targets) + return muti_loss_fusion(preds, targets) + + def forward(self,x): + + hx = x + + hxin = self.conv_in(hx) + hx = self.pool_in(hxin) + + #stage 1 + hx1 = self.stage1(hxin) + hx = self.pool12(hx1) + + #stage 2 + hx2 = self.stage2(hx) + hx = self.pool23(hx2) + + #stage 3 + hx3 = self.stage3(hx) + hx = self.pool34(hx3) + + #stage 4 + hx4 = self.stage4(hx) + hx = self.pool45(hx4) + + #stage 5 + hx5 = self.stage5(hx) + hx = self.pool56(hx5) + + #stage 6 + hx6 = self.stage6(hx) + hx6up = _upsample_like(hx6,hx5) + + #-------------------- decoder -------------------- + hx5d = self.stage5d(torch.cat((hx6up,hx5),1)) + hx5dup = _upsample_like(hx5d,hx4) + + hx4d = self.stage4d(torch.cat((hx5dup,hx4),1)) + hx4dup = _upsample_like(hx4d,hx3) + + hx3d = self.stage3d(torch.cat((hx4dup,hx3),1)) + hx3dup = _upsample_like(hx3d,hx2) + + hx2d = self.stage2d(torch.cat((hx3dup,hx2),1)) + hx2dup = _upsample_like(hx2d,hx1) + + hx1d = self.stage1d(torch.cat((hx2dup,hx1),1)) + + + #side output + d1 = self.side1(hx1d) + d1 = _upsample_like(d1,x) + + d2 = self.side2(hx2d) + d2 = _upsample_like(d2,x) + + d3 = self.side3(hx3d) + d3 = _upsample_like(d3,x) + + d4 = self.side4(hx4d) + d4 = _upsample_like(d4,x) + + d5 = self.side5(hx5d) + d5 = _upsample_like(d5,x) + + d6 = self.side6(hx6) + d6 = _upsample_like(d6,x) + + # d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1)) + + return [F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)],[hx1d,hx2d,hx3d,hx4d,hx5d,hx6] diff --git a/IS-Net/pytorch18.yml b/IS-Net/pytorch18.yml new file mode 100644 index 0000000..9e5d500 --- /dev/null +++ b/IS-Net/pytorch18.yml @@ -0,0 +1,92 @@ +name: pytorch18 +channels: + - conda-forge + - anaconda + - pytorch + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=4.5=1_gnu + - blas=1.0=mkl + - brotli=1.0.9=he6710b0_2 + - bzip2=1.0.8=h7b6447c_0 + - ca-certificates=2022.2.1=h06a4308_0 + - certifi=2021.10.8=py37h06a4308_2 + - cloudpickle=2.0.0=pyhd3eb1b0_0 + - colorama=0.4.4=pyhd3eb1b0_0 + - cudatoolkit=10.2.89=hfd86e86_1 + - cycler=0.11.0=pyhd3eb1b0_0 + - cytoolz=0.11.0=py37h7b6447c_0 + - dask-core=2021.10.0=pyhd3eb1b0_0 + - ffmpeg=4.3=hf484d3e_0 + - fonttools=4.25.0=pyhd3eb1b0_0 + - freetype=2.11.0=h70c0345_0 + - fsspec=2022.2.0=pyhd3eb1b0_0 + - gmp=6.2.1=h2531618_2 + - gnutls=3.6.15=he1e5248_0 + - imageio=2.9.0=pyhd3eb1b0_0 + - intel-openmp=2021.4.0=h06a4308_3561 + - jpeg=9b=h024ee3a_2 + - kiwisolver=1.3.2=py37h295c915_0 + - lame=3.100=h7b6447c_0 + - lcms2=2.12=h3be6417_0 + - ld_impl_linux-64=2.35.1=h7274673_9 + - libffi=3.3=he6710b0_2 + - libgcc-ng=9.3.0=h5101ec6_17 + - libgfortran-ng=7.5.0=ha8ba4b0_17 + - libgfortran4=7.5.0=ha8ba4b0_17 + - libgomp=9.3.0=h5101ec6_17 + - libiconv=1.15=h63c8f33_5 + - libidn2=2.3.2=h7f8727e_0 + - libpng=1.6.37=hbc83047_0 + - libstdcxx-ng=9.3.0=hd4cf53a_17 + - libtasn1=4.16.0=h27cfd23_0 + - libtiff=4.2.0=h85742a9_0 + - libunistring=0.9.10=h27cfd23_0 + - libuv=1.40.0=h7b6447c_0 + - libwebp-base=1.2.2=h7f8727e_0 + - locket=0.2.1=py37h06a4308_2 + - lz4-c=1.9.3=h295c915_1 + - matplotlib-base=3.5.1=py37ha18d171_1 + - mkl=2021.4.0=h06a4308_640 + - mkl-service=2.4.0=py37h7f8727e_0 + - mkl_fft=1.3.1=py37hd3c417c_0 + - mkl_random=1.2.2=py37h51133e4_0 + - munkres=1.1.4=py_0 + - ncurses=6.3=h7f8727e_2 + - nettle=3.7.3=hbbd107a_1 + - networkx=2.6.3=pyhd3eb1b0_0 + - ninja=1.10.2=py37hd09550d_3 + - numpy=1.21.2=py37h20f2e39_0 + - numpy-base=1.21.2=py37h79a1101_0 + - olefile=0.46=py37_0 + - openh264=2.1.1=h4ff587b_0 + - openssl=1.1.1n=h7f8727e_0 + - packaging=21.3=pyhd3eb1b0_0 + - partd=1.2.0=pyhd3eb1b0_1 + - pillow=8.0.0=py37h9a89aac_0 + - pip=21.2.2=py37h06a4308_0 + - pyparsing=3.0.4=pyhd3eb1b0_0 + - python=3.7.11=h12debd9_0 + - python-dateutil=2.8.2=pyhd3eb1b0_0 + - pytorch=1.8.0=py3.7_cuda10.2_cudnn7.6.5_0 + - pywavelets=1.1.1=py37h7b6447c_2 + - pyyaml=6.0=py37h7f8727e_1 + - readline=8.1.2=h7f8727e_1 + - scikit-image=0.15.0=py37hb3f55d8_2 + - scipy=1.7.3=py37hc147768_0 + - setuptools=58.0.4=py37h06a4308_0 + - six=1.16.0=pyhd3eb1b0_1 + - sqlite=3.38.0=hc218d9a_0 + - tk=8.6.11=h1ccaba5_0 + - toolz=0.11.2=pyhd3eb1b0_0 + - torchaudio=0.8.0=py37 + - torchvision=0.9.0=py37_cu102 + - tqdm=4.63.0=pyhd8ed1ab_0 + - typing_extensions=3.10.0.2=pyh06a4308_0 + - wheel=0.37.1=pyhd3eb1b0_0 + - xz=5.2.5=h7b6447c_0 + - yaml=0.2.5=h7b6447c_0 + - zlib=1.2.11=h7f8727e_4 + - zstd=1.4.9=haebb681_0 +prefix: /home/solar/anaconda3/envs/pytorch18 diff --git a/IS-Net/requirements.txt b/IS-Net/requirements.txt new file mode 100644 index 0000000..1dc3af5 --- /dev/null +++ b/IS-Net/requirements.txt @@ -0,0 +1,87 @@ +# This file may be used to create an environment using: +# $ conda create --name --file +# platform: linux-64 +_libgcc_mutex=0.1=main +_openmp_mutex=4.5=1_gnu +blas=1.0=mkl +brotli=1.0.9=he6710b0_2 +bzip2=1.0.8=h7b6447c_0 +ca-certificates=2022.2.1=h06a4308_0 +certifi=2021.10.8=py37h06a4308_2 +cloudpickle=2.0.0=pyhd3eb1b0_0 +colorama=0.4.4=pyhd3eb1b0_0 +cudatoolkit=10.2.89=hfd86e86_1 +cycler=0.11.0=pyhd3eb1b0_0 +cytoolz=0.11.0=py37h7b6447c_0 +dask-core=2021.10.0=pyhd3eb1b0_0 +ffmpeg=4.3=hf484d3e_0 +fonttools=4.25.0=pyhd3eb1b0_0 +freetype=2.11.0=h70c0345_0 +fsspec=2022.2.0=pyhd3eb1b0_0 +gmp=6.2.1=h2531618_2 +gnutls=3.6.15=he1e5248_0 +imageio=2.9.0=pyhd3eb1b0_0 +intel-openmp=2021.4.0=h06a4308_3561 +jpeg=9b=h024ee3a_2 +kiwisolver=1.3.2=py37h295c915_0 +lame=3.100=h7b6447c_0 +lcms2=2.12=h3be6417_0 +ld_impl_linux-64=2.35.1=h7274673_9 +libffi=3.3=he6710b0_2 +libgcc-ng=9.3.0=h5101ec6_17 +libgfortran-ng=7.5.0=ha8ba4b0_17 +libgfortran4=7.5.0=ha8ba4b0_17 +libgomp=9.3.0=h5101ec6_17 +libiconv=1.15=h63c8f33_5 +libidn2=2.3.2=h7f8727e_0 +libpng=1.6.37=hbc83047_0 +libstdcxx-ng=9.3.0=hd4cf53a_17 +libtasn1=4.16.0=h27cfd23_0 +libtiff=4.2.0=h85742a9_0 +libunistring=0.9.10=h27cfd23_0 +libuv=1.40.0=h7b6447c_0 +libwebp-base=1.2.2=h7f8727e_0 +locket=0.2.1=py37h06a4308_2 +lz4-c=1.9.3=h295c915_1 +matplotlib-base=3.5.1=py37ha18d171_1 +mkl=2021.4.0=h06a4308_640 +mkl-service=2.4.0=py37h7f8727e_0 +mkl_fft=1.3.1=py37hd3c417c_0 +mkl_random=1.2.2=py37h51133e4_0 +munkres=1.1.4=py_0 +ncurses=6.3=h7f8727e_2 +nettle=3.7.3=hbbd107a_1 +networkx=2.6.3=pyhd3eb1b0_0 +ninja=1.10.2=py37hd09550d_3 +numpy=1.21.2=py37h20f2e39_0 +numpy-base=1.21.2=py37h79a1101_0 +olefile=0.46=py37_0 +openh264=2.1.1=h4ff587b_0 +openssl=1.1.1n=h7f8727e_0 +packaging=21.3=pyhd3eb1b0_0 +partd=1.2.0=pyhd3eb1b0_1 +pillow=8.0.0=py37h9a89aac_0 +pip=21.2.2=py37h06a4308_0 +pyparsing=3.0.4=pyhd3eb1b0_0 +python=3.7.11=h12debd9_0 +python-dateutil=2.8.2=pyhd3eb1b0_0 +pytorch=1.8.0=py3.7_cuda10.2_cudnn7.6.5_0 +pywavelets=1.1.1=py37h7b6447c_2 +pyyaml=6.0=py37h7f8727e_1 +readline=8.1.2=h7f8727e_1 +scikit-image=0.15.0=py37hb3f55d8_2 +scipy=1.7.3=py37hc147768_0 +setuptools=58.0.4=py37h06a4308_0 +six=1.16.0=pyhd3eb1b0_1 +sqlite=3.38.0=hc218d9a_0 +tk=8.6.11=h1ccaba5_0 +toolz=0.11.2=pyhd3eb1b0_0 +torchaudio=0.8.0=py37 +torchvision=0.9.0=py37_cu102 +tqdm=4.63.0=pyhd8ed1ab_0 +typing_extensions=3.10.0.2=pyh06a4308_0 +wheel=0.37.1=pyhd3eb1b0_0 +xz=5.2.5=h7b6447c_0 +yaml=0.2.5=h7b6447c_0 +zlib=1.2.11=h7f8727e_4 +zstd=1.4.9=haebb681_0 diff --git a/IS-Net/train_valid_inference_main.py b/IS-Net/train_valid_inference_main.py new file mode 100644 index 0000000..a377597 --- /dev/null +++ b/IS-Net/train_valid_inference_main.py @@ -0,0 +1,713 @@ +import os +import time +import numpy as np +from skimage import io +import time + +import torch, gc +import torch.nn as nn +from torch.autograd import Variable +import torch.optim as optim +import torch.nn.functional as F + +from data_loader_cache import get_im_gt_name_dict, create_dataloaders, GOSRandomHFlip, GOSResize, GOSRandomCrop, GOSNormalize #GOSDatasetCache, +from basics import f1_mae_torch #normPRED, GOSPRF1ScoresCache,f1score_torch, +from models import * + +def get_gt_encoder(train_dataloaders, train_datasets, valid_dataloaders, valid_datasets, hypar, train_dataloaders_val, train_datasets_val): #model_path, model_save_fre, max_ite=1000000): + + # train_dataloaders, train_datasets = create_dataloaders(train_nm_im_gt_list, + # cache_size = hypar["cache_size"], + # cache_boost = hypar["cache_boost_train"], + # my_transforms = [ + # GOSRandomHFlip(), + # # GOSResize(hypar["input_size"]), + # # GOSRandomCrop(hypar["crop_size"]), + # GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]), + # ], + # batch_size = hypar["batch_size_train"], + # shuffle = True) + + torch.manual_seed(hypar["seed"]) + if torch.cuda.is_available(): + torch.cuda.manual_seed(hypar["seed"]) + + print("define gt encoder ...") + net = ISNetGTEncoder() #UNETGTENCODERCombine() + ## load the existing model gt encoder + if(hypar["gt_encoder_model"]!=""): + model_path = hypar["model_path"]+"/"+hypar["gt_encoder_model"] + if torch.cuda.is_available(): + net.load_state_dict(torch.load(model_path)) + net.cuda() + else: + net.load_state_dict(torch.load(model_path,map_location="cpu")) + print("gt encoder restored from the saved weights ...") + return net ############ + + if torch.cuda.is_available(): + net.cuda() + + print("--- define optimizer for GT Encoder---") + optimizer = optim.Adam(net.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=0) + + model_path = hypar["model_path"] + model_save_fre = hypar["model_save_fre"] + max_ite = hypar["max_ite"] + batch_size_train = hypar["batch_size_train"] + batch_size_valid = hypar["batch_size_valid"] + + if(not os.path.exists(model_path)): + os.mkdir(model_path) + + ite_num = hypar["start_ite"] # count the total iteration number + ite_num4val = 0 # + running_loss = 0.0 # count the toal loss + running_tar_loss = 0.0 # count the target output loss + last_f1 = [0 for x in range(len(valid_dataloaders))] + + train_num = train_datasets[0].__len__() + + net.train() + + start_last = time.time() + gos_dataloader = train_dataloaders[0] + epoch_num = hypar["max_epoch_num"] + notgood_cnt = 0 + for epoch in range(epoch_num): ## set the epoch num as 100000 + + for i, data in enumerate(gos_dataloader): + + if(ite_num >= max_ite): + print("Training Reached the Maximal Iteration Number ", max_ite) + exit() + + # start_read = time.time() + ite_num = ite_num + 1 + ite_num4val = ite_num4val + 1 + + # get the inputs + labels = data['label'] + + if(hypar["model_digit"]=="full"): + labels = labels.type(torch.FloatTensor) + else: + labels = labels.type(torch.HalfTensor) + + # wrap them in Variable + if torch.cuda.is_available(): + labels_v = Variable(labels.cuda(), requires_grad=False) + else: + labels_v = Variable(labels, requires_grad=False) + + # print("time lapse for data preparation: ", time.time()-start_read, ' s') + + # y zero the parameter gradients + start_inf_loss_back = time.time() + optimizer.zero_grad() + + ds, fs = net(labels_v)#net(inputs_v) + loss2, loss = net.compute_loss(ds, labels_v) + + loss.backward() + optimizer.step() + + running_loss += loss.item() + running_tar_loss += loss2.item() + + # del outputs, loss + del ds, loss2, loss + end_inf_loss_back = time.time()-start_inf_loss_back + + print("GT Encoder Training>>>"+model_path.split('/')[-1]+" - [epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f, time-per-iter: %3f s, time_read: %3f" % ( + epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val, time.time()-start_last, time.time()-start_last-end_inf_loss_back)) + start_last = time.time() + + if ite_num % model_save_fre == 0: # validate every 2000 iterations + notgood_cnt += 1 + # net.eval() + # tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time = valid_gt_encoder(net, valid_dataloaders, valid_datasets, hypar, epoch) + tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time = valid_gt_encoder(net, train_dataloaders_val, train_datasets_val, hypar, epoch) + + net.train() # resume train + + tmp_out = 0 + print("last_f1:",last_f1) + print("tmp_f1:",tmp_f1) + for fi in range(len(last_f1)): + if(tmp_f1[fi]>last_f1[fi]): + tmp_out = 1 + print("tmp_out:",tmp_out) + if(tmp_out): + notgood_cnt = 0 + last_f1 = tmp_f1 + tmp_f1_str = [str(round(f1x,4)) for f1x in tmp_f1] + tmp_mae_str = [str(round(mx,4)) for mx in tmp_mae] + maxf1 = '_'.join(tmp_f1_str) + meanM = '_'.join(tmp_mae_str) + # .cpu().detach().numpy() + model_name = "/GTENCODER-gpu_itr_"+str(ite_num)+\ + "_traLoss_"+str(np.round(running_loss / ite_num4val,4))+\ + "_traTarLoss_"+str(np.round(running_tar_loss / ite_num4val,4))+\ + "_valLoss_"+str(np.round(val_loss /(i_val+1),4))+\ + "_valTarLoss_"+str(np.round(tar_loss /(i_val+1),4)) + \ + "_maxF1_" + maxf1 + \ + "_mae_" + meanM + \ + "_time_" + str(np.round(np.mean(np.array(tmp_time))/batch_size_valid,6))+".pth" + torch.save(net.state_dict(), model_path + model_name) + + running_loss = 0.0 + running_tar_loss = 0.0 + ite_num4val = 0 + + if(tmp_f1[0]>0.99): + print("GT encoder is well-trained and obtained...") + return net + + if(notgood_cnt >= hypar["early_stop"]): + print("No improvements in the last "+str(notgood_cnt)+" validation periods, so training stopped !") + exit() + + print("Training Reaches The Maximum Epoch Number") + return net + +def valid_gt_encoder(net, valid_dataloaders, valid_datasets, hypar, epoch=0): + net.eval() + print("Validating...") + epoch_num = hypar["max_epoch_num"] + + val_loss = 0.0 + tar_loss = 0.0 + + + tmp_f1 = [] + tmp_mae = [] + tmp_time = [] + + start_valid = time.time() + for k in range(len(valid_dataloaders)): + + valid_dataloader = valid_dataloaders[k] + valid_dataset = valid_datasets[k] + + val_num = valid_dataset.__len__() + mybins = np.arange(0,256) + PRE = np.zeros((val_num,len(mybins)-1)) + REC = np.zeros((val_num,len(mybins)-1)) + F1 = np.zeros((val_num,len(mybins)-1)) + MAE = np.zeros((val_num)) + + val_cnt = 0.0 + for i_val, data_val in enumerate(valid_dataloader): + + # imidx_val, inputs_val, labels_val, shapes_val = data_val['imidx'], data_val['image'], data_val['label'], data_val['shape'] + imidx_val, labels_val, shapes_val = data_val['imidx'], data_val['label'], data_val['shape'] + + if(hypar["model_digit"]=="full"): + labels_val = labels_val.type(torch.FloatTensor) + else: + labels_val = labels_val.type(torch.HalfTensor) + + # wrap them in Variable + if torch.cuda.is_available(): + labels_val_v = Variable(labels_val.cuda(), requires_grad=False) + else: + labels_val_v = Variable(labels_val,requires_grad=False) + + t_start = time.time() + ds_val = net(labels_val_v)[0] + t_end = time.time()-t_start + tmp_time.append(t_end) + + # loss2_val, loss_val = muti_loss_fusion(ds_val, labels_val_v) + loss2_val, loss_val = net.compute_loss(ds_val, labels_val_v) + + # compute F measure + for t in range(hypar["batch_size_valid"]): + val_cnt = val_cnt + 1.0 + print("num of val: ", val_cnt) + i_test = imidx_val[t].data.numpy() + + pred_val = ds_val[0][t,:,:,:] # B x 1 x H x W + + ## recover the prediction spatial size to the orignal image size + pred_val = torch.squeeze(F.upsample(torch.unsqueeze(pred_val,0),(shapes_val[t][0],shapes_val[t][1]),mode='bilinear')) + + ma = torch.max(pred_val) + mi = torch.min(pred_val) + pred_val = (pred_val-mi)/(ma-mi) # max = 1 + # pred_val = normPRED(pred_val) + + gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255 + with torch.no_grad(): + gt = torch.tensor(gt).cuda() + + pre,rec,f1,mae = f1_mae_torch(pred_val*255, gt, valid_dataset, i_test, mybins, hypar) + + PRE[i_test,:]=pre + REC[i_test,:] = rec + F1[i_test,:] = f1 + MAE[i_test] = mae + + del ds_val, gt + gc.collect() + torch.cuda.empty_cache() + + # if(loss_val.data[0]>1): + val_loss += loss_val.item()#data[0] + tar_loss += loss2_val.item()#data[0] + + print("[validating: %5d/%5d] val_ls:%f, tar_ls: %f, f1: %f, mae: %f, time: %f"% (i_val, val_num, val_loss / (i_val + 1), tar_loss / (i_val + 1), np.amax(F1[i_test,:]), MAE[i_test],t_end)) + + del loss2_val, loss_val + + print('============================') + PRE_m = np.mean(PRE,0) + REC_m = np.mean(REC,0) + f1_m = (1+0.3)*PRE_m*REC_m/(0.3*PRE_m+REC_m+1e-8) + # print('--------------:', np.mean(f1_m)) + tmp_f1.append(np.amax(f1_m)) + tmp_mae.append(np.mean(MAE)) + print("The max F1 Score: %f"%(np.max(f1_m))) + print("MAE: ", np.mean(MAE)) + + # print('[epoch: %3d/%3d, ite: %5d] tra_ls: %3f, val_ls: %3f, tar_ls: %3f, maxf1: %3f, val_time: %6f'% (epoch + 1, epoch_num, ite_num, running_loss / ite_num4val, val_loss/val_cnt, tar_loss/val_cnt, tmp_f1[-1], time.time()-start_valid)) + + return tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time + +def train(net, optimizer, train_dataloaders, train_datasets, valid_dataloaders, valid_datasets, hypar,train_dataloaders_val, train_datasets_val): #model_path, model_save_fre, max_ite=1000000): + + if hypar["interm_sup"]: + print("Get the gt encoder ...") + featurenet = get_gt_encoder(train_dataloaders, train_datasets, valid_dataloaders, valid_datasets, hypar,train_dataloaders_val, train_datasets_val) + ## freeze the weights of gt encoder + for param in featurenet.parameters(): + param.requires_grad=False + + + model_path = hypar["model_path"] + model_save_fre = hypar["model_save_fre"] + max_ite = hypar["max_ite"] + batch_size_train = hypar["batch_size_train"] + batch_size_valid = hypar["batch_size_valid"] + + if(not os.path.exists(model_path)): + os.mkdir(model_path) + + ite_num = hypar["start_ite"] # count the toal iteration number + ite_num4val = 0 # + running_loss = 0.0 # count the toal loss + running_tar_loss = 0.0 # count the target output loss + last_f1 = [0 for x in range(len(valid_dataloaders))] + + train_num = train_datasets[0].__len__() + + net.train() + + start_last = time.time() + gos_dataloader = train_dataloaders[0] + epoch_num = hypar["max_epoch_num"] + notgood_cnt = 0 + for epoch in range(epoch_num): ## set the epoch num as 100000 + + for i, data in enumerate(gos_dataloader): + + if(ite_num >= max_ite): + print("Training Reached the Maximal Iteration Number ", max_ite) + exit() + + # start_read = time.time() + ite_num = ite_num + 1 + ite_num4val = ite_num4val + 1 + + # get the inputs + inputs, labels = data['image'], data['label'] + + if(hypar["model_digit"]=="full"): + inputs = inputs.type(torch.FloatTensor) + labels = labels.type(torch.FloatTensor) + else: + inputs = inputs.type(torch.HalfTensor) + labels = labels.type(torch.HalfTensor) + + # wrap them in Variable + if torch.cuda.is_available(): + inputs_v, labels_v = Variable(inputs.cuda(), requires_grad=False), Variable(labels.cuda(), requires_grad=False) + else: + inputs_v, labels_v = Variable(inputs, requires_grad=False), Variable(labels, requires_grad=False) + + # print("time lapse for data preparation: ", time.time()-start_read, ' s') + + # y zero the parameter gradients + start_inf_loss_back = time.time() + optimizer.zero_grad() + + if hypar["interm_sup"]: + # forward + backward + optimize + ds,dfs = net(inputs_v) + _,fs = featurenet(labels_v) ## extract the gt encodings + loss2, loss = net.compute_loss_kl(ds, labels_v, dfs, fs, mode='MSE') + else: + # forward + backward + optimize + ds,_ = net(inputs_v) + loss2, loss = muti_loss_fusion(ds, labels_v) + + loss.backward() + optimizer.step() + + # # print statistics + running_loss += loss.item() + running_tar_loss += loss2.item() + + # del outputs, loss + del ds, loss2, loss + end_inf_loss_back = time.time()-start_inf_loss_back + + print(">>>"+model_path.split('/')[-1]+" - [epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f, time-per-iter: %3f s, time_read: %3f" % ( + epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val, time.time()-start_last, time.time()-start_last-end_inf_loss_back)) + start_last = time.time() + + if ite_num % model_save_fre == 0: # validate every 2000 iterations + notgood_cnt += 1 + net.eval() + tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time = valid(net, valid_dataloaders, valid_datasets, hypar, epoch) + net.train() # resume train + + tmp_out = 0 + print("last_f1:",last_f1) + print("tmp_f1:",tmp_f1) + for fi in range(len(last_f1)): + if(tmp_f1[fi]>last_f1[fi]): + tmp_out = 1 + print("tmp_out:",tmp_out) + if(tmp_out): + notgood_cnt = 0 + last_f1 = tmp_f1 + tmp_f1_str = [str(round(f1x,4)) for f1x in tmp_f1] + tmp_mae_str = [str(round(mx,4)) for mx in tmp_mae] + maxf1 = '_'.join(tmp_f1_str) + meanM = '_'.join(tmp_mae_str) + # .cpu().detach().numpy() + model_name = "/gpu_itr_"+str(ite_num)+\ + "_traLoss_"+str(np.round(running_loss / ite_num4val,4))+\ + "_traTarLoss_"+str(np.round(running_tar_loss / ite_num4val,4))+\ + "_valLoss_"+str(np.round(val_loss /(i_val+1),4))+\ + "_valTarLoss_"+str(np.round(tar_loss /(i_val+1),4)) + \ + "_maxF1_" + maxf1 + \ + "_mae_" + meanM + \ + "_time_" + str(np.round(np.mean(np.array(tmp_time))/batch_size_valid,6))+".pth" + torch.save(net.state_dict(), model_path + model_name) + + running_loss = 0.0 + running_tar_loss = 0.0 + ite_num4val = 0 + + if(notgood_cnt >= hypar["early_stop"]): + print("No improvements in the last "+str(notgood_cnt)+" validation periods, so training stopped !") + exit() + + print("Training Reaches The Maximum Epoch Number") + +def valid(net, valid_dataloaders, valid_datasets, hypar, epoch=0): + net.eval() + print("Validating...") + epoch_num = hypar["max_epoch_num"] + + val_loss = 0.0 + tar_loss = 0.0 + val_cnt = 0.0 + + tmp_f1 = [] + tmp_mae = [] + tmp_time = [] + + start_valid = time.time() + + for k in range(len(valid_dataloaders)): + + valid_dataloader = valid_dataloaders[k] + valid_dataset = valid_datasets[k] + + val_num = valid_dataset.__len__() + mybins = np.arange(0,256) + PRE = np.zeros((val_num,len(mybins)-1)) + REC = np.zeros((val_num,len(mybins)-1)) + F1 = np.zeros((val_num,len(mybins)-1)) + MAE = np.zeros((val_num)) + + for i_val, data_val in enumerate(valid_dataloader): + val_cnt = val_cnt + 1.0 + imidx_val, inputs_val, labels_val, shapes_val = data_val['imidx'], data_val['image'], data_val['label'], data_val['shape'] + + if(hypar["model_digit"]=="full"): + inputs_val = inputs_val.type(torch.FloatTensor) + labels_val = labels_val.type(torch.FloatTensor) + else: + inputs_val = inputs_val.type(torch.HalfTensor) + labels_val = labels_val.type(torch.HalfTensor) + + # wrap them in Variable + if torch.cuda.is_available(): + inputs_val_v, labels_val_v = Variable(inputs_val.cuda(), requires_grad=False), Variable(labels_val.cuda(), requires_grad=False) + else: + inputs_val_v, labels_val_v = Variable(inputs_val, requires_grad=False), Variable(labels_val,requires_grad=False) + + t_start = time.time() + ds_val = net(inputs_val_v)[0] + t_end = time.time()-t_start + tmp_time.append(t_end) + + # loss2_val, loss_val = muti_loss_fusion(ds_val, labels_val_v) + loss2_val, loss_val = net.compute_loss(ds_val, labels_val_v) + + # compute F measure + for t in range(hypar["batch_size_valid"]): + i_test = imidx_val[t].data.numpy() + + pred_val = ds_val[0][t,:,:,:] # B x 1 x H x W + + ## recover the prediction spatial size to the orignal image size + pred_val = torch.squeeze(F.upsample(torch.unsqueeze(pred_val,0),(shapes_val[t][0],shapes_val[t][1]),mode='bilinear')) + + # pred_val = normPRED(pred_val) + ma = torch.max(pred_val) + mi = torch.min(pred_val) + pred_val = (pred_val-mi)/(ma-mi) # max = 1 + + gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255 + with torch.no_grad(): + gt = torch.tensor(gt).cuda() + + pre,rec,f1,mae = f1_mae_torch(pred_val*255, gt, valid_dataset, i_test, mybins, hypar) + + + PRE[i_test,:]=pre + REC[i_test,:] = rec + F1[i_test,:] = f1 + MAE[i_test] = mae + + del ds_val, gt + gc.collect() + torch.cuda.empty_cache() + + # if(loss_val.data[0]>1): + val_loss += loss_val.item()#data[0] + tar_loss += loss2_val.item()#data[0] + + print("[validating: %5d/%5d] val_ls:%f, tar_ls: %f, f1: %f, mae: %f, time: %f"% (i_val, val_num, val_loss / (i_val + 1), tar_loss / (i_val + 1), np.amax(F1[i_test,:]), MAE[i_test],t_end)) + + del loss2_val, loss_val + + print('============================') + PRE_m = np.mean(PRE,0) + REC_m = np.mean(REC,0) + f1_m = (1+0.3)*PRE_m*REC_m/(0.3*PRE_m+REC_m+1e-8) + + tmp_f1.append(np.amax(f1_m)) + tmp_mae.append(np.mean(MAE)) + + return tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time + +def main(train_datasets, + valid_datasets, + hypar): # model: "train", "test" + + ### --- Step 1: Build datasets and dataloaders --- + dataloaders_train = [] + dataloaders_valid = [] + + if(hypar["mode"]=="train"): + print("--- create training dataloader ---") + ## collect training dataset + train_nm_im_gt_list = get_im_gt_name_dict(train_datasets, flag="train") + ## build dataloader for training datasets + train_dataloaders, train_datasets = create_dataloaders(train_nm_im_gt_list, + cache_size = hypar["cache_size"], + cache_boost = hypar["cache_boost_train"], + my_transforms = [ + GOSRandomHFlip(), ## this line can be uncommented for horizontal flip augmetation + # GOSResize(hypar["input_size"]), + # GOSRandomCrop(hypar["crop_size"]), ## this line can be uncommented for randomcrop augmentation + GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]), + ], + batch_size = hypar["batch_size_train"], + shuffle = True) + train_dataloaders_val, train_datasets_val = create_dataloaders(train_nm_im_gt_list, + cache_size = hypar["cache_size"], + cache_boost = hypar["cache_boost_train"], + my_transforms = [ + GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]), + ], + batch_size = hypar["batch_size_valid"], + shuffle = False) + print(len(train_dataloaders), " train dataloaders created") + + print("--- create valid dataloader ---") + ## build dataloader for validation or testing + valid_nm_im_gt_list = get_im_gt_name_dict(valid_datasets, flag="valid") + ## build dataloader for training datasets + valid_dataloaders, valid_datasets = create_dataloaders(valid_nm_im_gt_list, + cache_size = hypar["cache_size"], + cache_boost = hypar["cache_boost_valid"], + my_transforms = [ + GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]), + # GOSResize(hypar["input_size"]) + ], + batch_size=hypar["batch_size_valid"], + shuffle=False) + print(len(valid_dataloaders), " valid dataloaders created") + # print(valid_datasets[0]["data_name"]) + + ### --- Step 2: Build Model and Optimizer --- + print("--- build model ---") + net = hypar["model"]#GOSNETINC(3,1) + + # convert to half precision + if(hypar["model_digit"]=="half"): + net.half() + for layer in net.modules(): + if isinstance(layer, nn.BatchNorm2d): + layer.float() + + if torch.cuda.is_available(): + net.cuda() + + if(hypar["restore_model"]!=""): + print("restore model from:") + print(hypar["model_path"]+"/"+hypar["restore_model"]) + if torch.cuda.is_available(): + net.load_state_dict(torch.load(hypar["model_path"]+"/"+hypar["restore_model"])) + else: + net.load_state_dict(torch.load(hypar["model_path"]+"/"+hypar["restore_model"],map_location="cpu")) + + print("--- define optimizer ---") + optimizer = optim.Adam(net.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=0) + + ### --- Step 3: Train or Valid Model --- + if(hypar["mode"]=="train"): + train(net, + optimizer, + train_dataloaders, + train_datasets, + valid_dataloaders, + valid_datasets, + hypar, + train_dataloaders_val, train_datasets_val) + else: + valid(net, + valid_dataloaders, + valid_datasets, + hypar) + + +if __name__ == "__main__": + + ### --------------- STEP 1: Configuring the Train, Valid and Test datasets --------------- + ## configure the train, valid and inference datasets + train_datasets, valid_datasets = [], [] + dataset_1, dataset_1 = {}, {} + + dataset_tr = {"name": "DIS5K-TR", + "im_dir": "../DIS5K/DIS-TR/im", + "gt_dir": "../DIS5K/DIS-TR/gt", + "im_ext": ".jpg", + "gt_ext": ".png", + "cache_dir":"../DIS5K-Cache/DIS-TR"} + + dataset_vd = {"name": "DIS5K-VD", + "im_dir": "../DIS5K/DIS-VD/im", + "gt_dir": "../DIS5K/DIS-VD/gt", + "im_ext": ".jpg", + "gt_ext": ".png", + "cache_dir":"../DIS5K-Cache/DIS-VD"} + + dataset_te1 = {"name": "DIS5K-TE1", + "im_dir": "../DIS5K/DIS-TE1/im", + "gt_dir": "../DIS5K/DIS-TE1/gt", + "im_ext": ".jpg", + "gt_ext": ".png", + "cache_dir":"../DIS5K-Cache/DIS-TE1"} + + dataset_te2 = {"name": "DIS5K-TE2", + "im_dir": "../DIS5K/DIS-TE2/im", + "gt_dir": "../DIS5K/DIS-TE2/gt", + "im_ext": ".jpg", + "gt_ext": ".png", + "cache_dir":"../DIS5K-Cache/DIS-TE2"} + + dataset_te3 = {"name": "DIS5K-TE3", + "im_dir": "../DIS5K/DIS-TE3/im", + "gt_dir": "../DIS5K/DIS-TE3/gt", + "im_ext": ".jpg", + "gt_ext": ".png", + "cache_dir":"../DIS5K-Cache/DIS-TE3"} + + dataset_te4 = {"name": "DIS5K-TE4", + "im_dir": "../DIS5K/DIS-TE4/im", + "gt_dir": "../DIS5K/DIS-TE4/gt", + "im_ext": ".jpg", + "gt_ext": ".png", + "cache_dir":"../DIS5K-Cache/DIS-TE4"} + + train_datasets = [dataset_tr] ## users can create mutiple dictionary for setting a list of datasets as training set + # valid_datasets = [dataset_vd] ## users can create mutiple dictionary for setting a list of datasets as vaidation sets or inference sets + valid_datasets = [dataset_vd] #, dataset_te1, dataset_te2, dataset_te3, dataset_te4] # and hypar["mode"] = "valid" for inference, + + ### --------------- STEP 2: Configuring the hyperparamters for Training, validation and inferencing --------------- + hypar = {} + + ## -- 2.1. configure the model saving or restoring path -- + hypar["mode"] = "train" + ## "train": for training, + ## "valid": for validation and inferening, + ## in "valid" mode, it will calculate the accuracy as well as save the prediciton results into the "hypar["valid_out_dir"]", which shouldn't be "" + ## otherwise only accuracy will be calculated and no predictions will be saved + hypar["interm_sup"] = False ## indicate if activate intermediate feature supervision + + if hypar["mode"] == "train": + hypar["valid_out_dir"] = "" ## for "train" model leave it as "", for "valid"("inference") mode: set it according to your local directory + hypar["model_path"] ="../saved_models/IS-Net-test" ## model weights saving (or restoring) path + hypar["restore_model"] = "" ## name of the segmentation model weights .pth for resume training process from last stop or for the inferencing + hypar["start_ite"] = 0 ## start iteration for the training, can be changed to match the restored training process + hypar["gt_encoder_model"] = "" + else: ## configure the segmentation output path and the to-be-used model weights path + hypar["valid_out_dir"] = "../DIS5K-Results-test" ## output inferenced segmentation maps into this fold + hypar["model_path"] ="../saved_models/IS-Net" ## load trained weights from this path + hypar["restore_model"] = "isnet.pth" ## name of the to-be-loaded weights + + # if hypar["restore_model"]!="": + # hypar["start_ite"] = int(hypar["restore_model"].split("_")[2]) + + ## -- 2.2. choose floating point accuracy -- + hypar["model_digit"] = "full" ## indicates "half" or "full" accuracy of float number + hypar["seed"] = 0 + + ## -- 2.3. cache data spatial size -- + ## To handle large size input images, which take a lot of time for loading in training, + # we introduce the cache mechanism for pre-convering and resizing the jpg and png images into .pt file + hypar["cache_size"] = [1024, 1024] ## cached input spatial resolution, can be configured into different size + hypar["cache_boost_train"] = False ## "True" or "False", indicates wheather to load all the training datasets into RAM, True will greatly speed the training process while requires more RAM + hypar["cache_boost_valid"] = False ## "True" or "False", indicates wheather to load all the validation datasets into RAM, True will greatly speed the training process while requires more RAM + + ## --- 2.4. data augmentation parameters --- + hypar["input_size"] = [1024, 1024] ## mdoel input spatial size, usually use the same value hypar["cache_size"], which means we don't further resize the images + hypar["crop_size"] = [1024, 1024] ## random crop size from the input, it is usually set as smaller than hypar["cache_size"], e.g., [920,920] for data augmentation + hypar["random_flip_h"] = 1 ## horizontal flip, currently hard coded in the dataloader and it is not in use + hypar["random_flip_v"] = 0 ## vertical flip , currently not in use + + ## --- 2.5. define model --- + print("building model...") + hypar["model"] = ISNetDIS() #U2NETFASTFEATURESUP() + hypar["early_stop"] = 20 ## stop the training when no improvement in the past 20 validation periods, smaller numbers can be used here e.g., 5 or 10. + hypar["model_save_fre"] = 2000 ## valid and save model weights every 2000 iterations + + hypar["batch_size_train"] = 8 ## batch size for training + hypar["batch_size_valid"] = 1 ## batch size for validation and inferencing + print("batch size: ", hypar["batch_size_train"]) + + hypar["max_ite"] = 10000000 ## if early stop couldn't stop the training process, stop it by the max_ite_num + hypar["max_epoch_num"] = 1000000 ## if early stop and max_ite couldn't stop the training process, stop it by the max_epoch_num + + main(train_datasets, + valid_datasets, + hypar=hypar) diff --git a/README.md b/README.md index 3ed48c5..2197fae 100644 --- a/README.md +++ b/README.md @@ -2,27 +2,147 @@

+![dis5k-v1-sailship](figures/dis5k-v1-sailship.jpeg) + +
+ ## [Highly Accurate Dichotomous Image Segmentation (ECCV 2022)](https://arxiv.org/pdf/2203.03041.pdf) -[Xuebin Qin](https://xuebinqin.github.io/), [Hang Dai](https://scholar.google.co.uk/citations?user=6yvjpQQAAAAJ&hl=en), [Xiaobin Hu](https://scholar.google.de/citations?user=3lMuodUAAAAJ&hl=en), [Deng-Ping Fan*](https://dengpingfan.github.io/), [Ling Shao](https://scholar.google.com/citations?user=z84rLjoAAAAJ&hl=en) and [Luc Van Gool](https://scholar.google.com/citations?user=TwMib_QAAAAJ&hl=en). +[Xuebin Qin](https://xuebinqin.github.io/), [Hang Dai](https://scholar.google.co.uk/citations?user=6yvjpQQAAAAJ&hl=en), [Xiaobin Hu](https://scholar.google.de/citations?user=3lMuodUAAAAJ&hl=en), [Deng-Ping Fan*](https://dengpingfan.github.io/), [Ling Shao](https://scholar.google.com/citations?user=z84rLjoAAAAJ&hl=en), [Luc Van Gool](https://scholar.google.com/citations?user=TwMib_QAAAAJ&hl=en). +
-This is the official repo for our new project DIS: -[**[Project Page]**](https://xuebinqin.github.io/dis/index.html) +## This is the official repo for our newly formulated DIS task: +[**Project Page**](https://xuebinqin.github.io/dis/index.html), [**Arxiv**](https://arxiv.org/pdf/2203.03041.pdf). + +
## Updates !!! +
+ +## ** (2022-Jul.-17)** Our paper, code and dataset are now officially released!!! Please check our project page for more details: [**Project Page**](https://xuebinqin.github.io/dis/index.html).
** (2022-Jul.-5)** Our DIS work is now accepted by ECCV 2022, the code and dataset will be released before July 17th, 2022. Please be aware of our updates. -## DEMO -![ship-demo](figures/ship-demo.gif) -![bg-removal](figures/bg-removal.gif) -![view-move](figures/view-move.gif) -![motor-demo](figures/motor-demo.gif) +
-## Comparisons Against SOTAs +## 1. [Our DIS5K Dataset V1.0 (Version Alias: DIS5K Sailship)](https://xuebinqin.github.io/dis/index.html) + +
+ +### Download: [Google Drive](https://drive.google.com/file/d/1jOC2zK0GowBvEt03B7dGugCRDVAoeIqq/view?usp=sharing) or [Baidu Pan 提取码:rtgw](https://pan.baidu.com/s/1y6CQJYledfYyEO0C_Gejpw?pwd=rtgw) + +![dis5k-dataset-v1-sailship](figures/DIS5k-dataset-v1-sailship.png) +![complexities-qual](figures/complexities-qual.jpeg) +![categories](figures/categories.jpeg) + +
+ +## 2. APPLICATIONS of Our DIS5K Dataset + +
+ +### 3D Modeling +![3d-modeling](figures/3d-modeling.png) + +### Image Editing +![ship-demo](figures/ship-demo.gif) +### Art Design Materials +![bg-removal](figures/bg-removal.gif) +### Still Image Animation +![view-move](figures/view-move.gif) +### AR +![motor-demo](figures/motor-demo.gif) +### 3D Rendering +![video-3d](figures/video-3d.gif) + +
+ +## 3. Architecture of Our IS-Net + +
+ +![is-net](figures/is-net.png) + +
+ +## 4. Human Correction Efforts (HCE) + +
+ +![hce-metric](figures/hce-metric.png) + +
+ +## 5. Experimental Results + +
+ +### Predicted Maps, [(Google Drive)](https://drive.google.com/file/d/1FMtDLFrL6xVc41eKlLnuZWMBAErnKv0Y/view?usp=sharing), [(Baidu Pan 提取码:ph1d)](https://pan.baidu.com/s/1WUk2RYYpii2xzrvLna9Fsg?pwd=ph1d), of Our IS-Net and Other SOTAs + +### Qualitative Comparisons Against SOTAs ![qual-comp](figures/qual-comp.jpg) +### Quantitative Comparisons Against SOTAs +![qual-comp](figures/quan-comp.png) + +
+ +## 6. Run Our Code + +
+ +### (1) Clone this repo +``` +git clone https://github.com/xuebinqin/DIS.git +``` + +### (2) Configuring the environment: go to the root ```DIS``` folder and run +``` +conda env create -f pytorch18.yml +``` +Or you can check the ```requirements.txt``` to configure the dependancies. + +### (3) Train: +#### (a) Open ```train_valid_inference_main.py```, set the path of your to-be-inferenced ```train_datasets``` and ```valid_datasets```, e.g., ```valid_datasets=[dataset_vd]``` +#### (b) Set the ```hypar["mode"]``` to ```"train"``` +#### (c) Create a new folder ```your_model_weights``` in the directory ```saved_models``` and set it as the ```hypar["model_path"] ="../saved_models/your_model_weights"``` and make sure ```hypar["valid_out_dir"]```(line 668) is set to ```""```, otherwise the prediction maps of the validation stage will be saved to that directory, which will slow the training speed down +#### (d) Run +``` +python train_valid_inference_main.py +``` + + +### (4) Inference +#### (a). Download the pre-trained weights (for fair academic comparisons only, the optimized model for engineering or common use will be released soon) ```isnet.pth``` from [(Google Drive)](https://drive.google.com/file/d/1KyMpRjewZdyYfxHPYcd-ZbanIXtin0Sn/view?usp=sharing) or [(Baidu Pan 提取码:xbfk)](https://pan.baidu.com/s/1-X2WutiBkWPt-oakuvZ10w?pwd=xbfk) and store ```isnet.pth``` in ```saved_models/IS-Net``` +#### (b) Open ```train_valid_inference_main.py```, set the path of your to-be-inferenced ```valid_datasets```, e.g., ```valid_datasets=[dataset_te1, dataset_te2, dataset_te3, dataset_te4]``` +#### (c) Set the ```hypar["mode"]``` to ```"valid"``` +#### (d) Set the output directory of your predicted maps, e.g., ```hypar["valid_out_dir"] = "../DIS5K-Results-test"``` +#### (e) Run +``` +python train_valid_inference_main.py +``` + +### (5) Use of our Human Correction Efforts(HCE) metric, set the ground truth directory ```gt_root``` and the prediction directory ```pred_root```. To reduce the time costs for computing HCE, the skeletion of the DIS5K dataset can be pre-computed and stored in ```gt_ske_root```. If ```gt_ske_root=""```, the HCE code will compute the skeleton online which usually takes a lot for time for large size ground truth. Then, run ```python hce_metric_main.py```. Other metrics are evaluated based on the [SOCToolbox](https://github.com/mczhuge/SOCToolbox). + +
+ +## 7. Term of Use +Our code and evaluation metric use Apache License 2.0. The Terms of use for our DIS5K dataset is provided as [DIS5K-Dataset-Terms-of-Use.pdf](DIS5K-Dataset-Terms-of-Use.pdf). + +
+ +## Acknowledgements + +
+ +We would like to thank Dr. [Ibrahim Almakky](https://scholar.google.co.uk/citations?user=T9MTcK0AAAAJ&hl=en) for his helps in implementing the dataloader cache machanism of loading large-size training samples and Jiayi Zhu for his efforts in re-organizing our code and dataset. + +
+ ## Citation + +
+ ``` @InProceedings{qin2022, author={Xuebin Qin and Hang Dai and Xiaobin Hu and Deng-Ping Fan and Ling Shao and Luc Van Gool}, @@ -31,8 +151,15 @@ This is the official repo for our new project DIS: year={2022} } ``` -## Our Previous Works + +
+ +## Our Previous Works: [U2-Net](https://github.com/xuebinqin/U-2-Net), [BASNet](https://github.com/xuebinqin/BASNet). + +
+ ``` + @InProceedings{Qin_2020_PR, title = {U2-Net: Going Deeper with Nested U-Structure for Salient Object Detection}, author = {Qin, Xuebin and Zhang, Zichen and Huang, Chenyang and Dehghan, Masood and Zaiane, Osmar and Jagersand, Martin}, @@ -42,13 +169,6 @@ This is the official repo for our new project DIS: year = {2020} } -@article{qin2021boundary, - title={Boundary-aware segmentation network for mobile and web applications}, - author={Qin, Xuebin and Fan, Deng-Ping and Huang, Chenyang and Diagne, Cyril and Zhang, Zichen and Sant'Anna, Adri{\`a} Cabeza and Suarez, Albert and Jagersand, Martin and Shao, Ling}, - journal={arXiv preprint arXiv:2101.04704}, - year={2021} -} - @InProceedings{Qin_2019_CVPR, author = {Qin, Xuebin and Zhang, Zichen and Huang, Chenyang and Gao, Chao and Dehghan, Masood and Jagersand, Martin}, title = {BASNet: Boundary-Aware Salient Object Detection}, @@ -56,3 +176,10 @@ This is the official repo for our new project DIS: month = {June}, year = {2019} } + +@article{qin2021boundary, + title={Boundary-aware segmentation network for mobile and web applications}, + author={Qin, Xuebin and Fan, Deng-Ping and Huang, Chenyang and Diagne, Cyril and Zhang, Zichen and Sant'Anna, Adri{\`a} Cabeza and Suarez, Albert and Jagersand, Martin and Shao, Ling}, + journal={arXiv preprint arXiv:2101.04704}, + year={2021} +} \ No newline at end of file diff --git a/benchmark b/benchmark deleted file mode 100644 index 8b13789..0000000 --- a/benchmark +++ /dev/null @@ -1 +0,0 @@ - diff --git a/figures/.DS_Store b/figures/.DS_Store new file mode 100644 index 0000000..98c272b Binary files /dev/null and b/figures/.DS_Store differ diff --git a/figures/3d-modeling.png b/figures/3d-modeling.png new file mode 100644 index 0000000..fbf36be Binary files /dev/null and b/figures/3d-modeling.png differ diff --git a/figures/DIS5k-dataset-v1-sailship.png b/figures/DIS5k-dataset-v1-sailship.png new file mode 100644 index 0000000..df4b412 Binary files /dev/null and b/figures/DIS5k-dataset-v1-sailship.png differ diff --git a/figures/categories.jpeg b/figures/categories.jpeg new file mode 100644 index 0000000..4e50a8f Binary files /dev/null and b/figures/categories.jpeg differ diff --git a/figures/complexities-qual.jpeg b/figures/complexities-qual.jpeg new file mode 100644 index 0000000..22242db Binary files /dev/null and b/figures/complexities-qual.jpeg differ diff --git a/figures/dis5k-v1-sailship.jpeg b/figures/dis5k-v1-sailship.jpeg new file mode 100644 index 0000000..9cfb427 Binary files /dev/null and b/figures/dis5k-v1-sailship.jpeg differ diff --git a/figures/hce-metric.png b/figures/hce-metric.png new file mode 100644 index 0000000..f39698f Binary files /dev/null and b/figures/hce-metric.png differ diff --git a/figures/is-net.png b/figures/is-net.png new file mode 100644 index 0000000..c2517a3 Binary files /dev/null and b/figures/is-net.png differ diff --git a/figures/quan-comp.png b/figures/quan-comp.png new file mode 100644 index 0000000..480206c Binary files /dev/null and b/figures/quan-comp.png differ diff --git a/figures/statistics.jpeg b/figures/statistics.jpeg new file mode 100644 index 0000000..6e10661 Binary files /dev/null and b/figures/statistics.jpeg differ diff --git a/figures/video-3d.gif b/figures/video-3d.gif new file mode 100644 index 0000000..9433cdc Binary files /dev/null and b/figures/video-3d.gif differ diff --git a/saved_models/.DS_Store b/saved_models/.DS_Store new file mode 100644 index 0000000..ce35f61 Binary files /dev/null and b/saved_models/.DS_Store differ