official release of our isnet and dis5k
BIN
DIS5K-Dataset-Terms-of-Use.pdf
Normal file
BIN
IS-Net/__pycache__/basics.cpython-37.pyc
Normal file
BIN
IS-Net/__pycache__/data_loader_cache.cpython-37.pyc
Normal file
74
IS-Net/basics.py
Normal file
@ -0,0 +1,74 @@
|
||||
import os
|
||||
# os.environ['CUDA_VISIBLE_DEVICES'] = '2'
|
||||
from skimage import io, transform
|
||||
import torch
|
||||
import torchvision
|
||||
from torch.autograd import Variable
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torchvision import transforms, utils
|
||||
import torch.optim as optim
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import glob
|
||||
|
||||
def mae_torch(pred,gt):
|
||||
|
||||
h,w = gt.shape[0:2]
|
||||
sumError = torch.sum(torch.absolute(torch.sub(pred.float(), gt.float())))
|
||||
maeError = torch.divide(sumError,float(h)*float(w)*255.0+1e-4)
|
||||
|
||||
return maeError
|
||||
|
||||
def f1score_torch(pd,gt):
|
||||
|
||||
# print(gt.shape)
|
||||
gtNum = torch.sum((gt>128).float()*1) ## number of ground truth pixels
|
||||
|
||||
pp = pd[gt>128]
|
||||
nn = pd[gt<=128]
|
||||
|
||||
pp_hist =torch.histc(pp,bins=255,min=0,max=255)
|
||||
nn_hist = torch.histc(nn,bins=255,min=0,max=255)
|
||||
|
||||
|
||||
pp_hist_flip = torch.flipud(pp_hist)
|
||||
nn_hist_flip = torch.flipud(nn_hist)
|
||||
|
||||
pp_hist_flip_cum = torch.cumsum(pp_hist_flip, dim=0)
|
||||
nn_hist_flip_cum = torch.cumsum(nn_hist_flip, dim=0)
|
||||
|
||||
precision = (pp_hist_flip_cum)/(pp_hist_flip_cum + nn_hist_flip_cum + 1e-4)#torch.divide(pp_hist_flip_cum,torch.sum(torch.sum(pp_hist_flip_cum, nn_hist_flip_cum), 1e-4))
|
||||
recall = (pp_hist_flip_cum)/(gtNum + 1e-4)
|
||||
f1 = (1+0.3)*precision*recall/(0.3*precision+recall + 1e-4)
|
||||
|
||||
return torch.reshape(precision,(1,precision.shape[0])),torch.reshape(recall,(1,recall.shape[0])),torch.reshape(f1,(1,f1.shape[0]))
|
||||
|
||||
|
||||
def f1_mae_torch(pred, gt, valid_dataset, idx, mybins, hypar):
|
||||
|
||||
import time
|
||||
tic = time.time()
|
||||
|
||||
if(len(gt.shape)>2):
|
||||
gt = gt[:,:,0]
|
||||
|
||||
pre, rec, f1 = f1score_torch(pred,gt)
|
||||
mae = mae_torch(pred,gt)
|
||||
|
||||
|
||||
# hypar["valid_out_dir"] = hypar["valid_out_dir"]+"-eval" ###
|
||||
if(hypar["valid_out_dir"]!=""):
|
||||
if(not os.path.exists(hypar["valid_out_dir"])):
|
||||
os.mkdir(hypar["valid_out_dir"])
|
||||
dataset_folder = os.path.join(hypar["valid_out_dir"],valid_dataset.dataset["data_name"][idx])
|
||||
if(not os.path.exists(dataset_folder)):
|
||||
os.mkdir(dataset_folder)
|
||||
io.imsave(os.path.join(dataset_folder,valid_dataset.dataset["im_name"][idx]+".png"),pred.cpu().data.numpy().astype(np.uint8))
|
||||
print(valid_dataset.dataset["im_name"][idx]+".png")
|
||||
print("time for evaluation : ", time.time()-tic)
|
||||
|
||||
return pre.cpu().data.numpy(), rec.cpu().data.numpy(), f1.cpu().data.numpy(), mae.cpu().data.numpy()
|
380
IS-Net/data_loader_cache.py
Normal file
@ -0,0 +1,380 @@
|
||||
## data loader
|
||||
## Ackownledgement:
|
||||
## We would like to thank Dr. Ibrahim Almakky (https://scholar.google.co.uk/citations?user=T9MTcK0AAAAJ&hl=en)
|
||||
## for his helps in implementing cache machanism of our DIS dataloader.
|
||||
from __future__ import print_function, division
|
||||
|
||||
import numpy as np
|
||||
import random
|
||||
from copy import deepcopy
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
from skimage import io
|
||||
import os
|
||||
from glob import glob
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torchvision import transforms, utils
|
||||
from torchvision.transforms.functional import normalize
|
||||
import torch.nn.functional as F
|
||||
|
||||
#### --------------------- DIS dataloader cache ---------------------####
|
||||
|
||||
def get_im_gt_name_dict(datasets, flag='valid'):
|
||||
print("------------------------------", flag, "--------------------------------")
|
||||
name_im_gt_list = []
|
||||
for i in range(len(datasets)):
|
||||
print("--->>>", flag, " dataset ",i,"/",len(datasets)," ",datasets[i]["name"],"<<<---")
|
||||
tmp_im_list, tmp_gt_list = [], []
|
||||
tmp_im_list = glob(datasets[i]["im_dir"]+os.sep+'*'+datasets[i]["im_ext"])
|
||||
|
||||
# img_name_dict[im_dirs[i][0]] = tmp_im_list
|
||||
print('-im-',datasets[i]["name"],datasets[i]["im_dir"], ': ',len(tmp_im_list))
|
||||
|
||||
if(datasets[i]["gt_dir"]==""):
|
||||
print('-gt-', datasets[i]["name"], datasets[i]["gt_dir"], ': ', 'No Ground Truth Found')
|
||||
else:
|
||||
tmp_gt_list = [datasets[i]["gt_dir"]+os.sep+x.split(os.sep)[-1].split(datasets[i]["im_ext"])[0]+datasets[i]["gt_ext"] for x in tmp_im_list]
|
||||
# lbl_name_dict[im_dirs[i][0]] = tmp_gt_list
|
||||
print('-gt-', datasets[i]["name"],datasets[i]["gt_dir"], ': ',len(tmp_gt_list))
|
||||
|
||||
|
||||
if flag=="train": ## combine multiple training sets into one dataset
|
||||
if len(name_im_gt_list)==0:
|
||||
name_im_gt_list.append({"dataset_name":datasets[i]["name"],
|
||||
"im_path":tmp_im_list,
|
||||
"gt_path":tmp_gt_list,
|
||||
"im_ext":datasets[i]["im_ext"],
|
||||
"gt_ext":datasets[i]["gt_ext"],
|
||||
"cache_dir":datasets[i]["cache_dir"]})
|
||||
else:
|
||||
name_im_gt_list[0]["dataset_name"] = name_im_gt_list[0]["dataset_name"] + "_" + datasets[i]["name"]
|
||||
name_im_gt_list[0]["im_path"] = name_im_gt_list[0]["im_path"] + tmp_im_list
|
||||
name_im_gt_list[0]["gt_path"] = name_im_gt_list[0]["gt_path"] + tmp_gt_list
|
||||
if datasets[i]["im_ext"]!=".jpg" or datasets[i]["gt_ext"]!=".png":
|
||||
print("Error: Please make sure all you images and ground truth masks are in jpg and png format respectively !!!")
|
||||
exit()
|
||||
name_im_gt_list[0]["im_ext"] = ".jpg"
|
||||
name_im_gt_list[0]["gt_ext"] = ".png"
|
||||
name_im_gt_list[0]["cache_dir"] = os.sep.join(datasets[i]["cache_dir"].split(os.sep)[0:-1])+os.sep+name_im_gt_list[0]["dataset_name"]
|
||||
else: ## keep different validation or inference datasets as separate ones
|
||||
name_im_gt_list.append({"dataset_name":datasets[i]["name"],
|
||||
"im_path":tmp_im_list,
|
||||
"gt_path":tmp_gt_list,
|
||||
"im_ext":datasets[i]["im_ext"],
|
||||
"gt_ext":datasets[i]["gt_ext"],
|
||||
"cache_dir":datasets[i]["cache_dir"]})
|
||||
|
||||
return name_im_gt_list
|
||||
|
||||
def create_dataloaders(name_im_gt_list, cache_size=[], cache_boost=True, my_transforms=[], batch_size=1, shuffle=False):
|
||||
## model="train": return one dataloader for training
|
||||
## model="valid": return a list of dataloaders for validation or testing
|
||||
|
||||
gos_dataloaders = []
|
||||
gos_datasets = []
|
||||
# if(mode=="train"):
|
||||
if(len(name_im_gt_list)==0):
|
||||
return gos_dataloaders, gos_datasets
|
||||
|
||||
num_workers_ = 1
|
||||
if(batch_size>1):
|
||||
num_workers_ = 2
|
||||
if(batch_size>4):
|
||||
num_workers_ = 4
|
||||
if(batch_size>8):
|
||||
num_workers_ = 8
|
||||
|
||||
for i in range(0,len(name_im_gt_list)):
|
||||
gos_dataset = GOSDatasetCache([name_im_gt_list[i]],
|
||||
cache_size = cache_size,
|
||||
cache_path = name_im_gt_list[i]["cache_dir"],
|
||||
cache_boost = cache_boost,
|
||||
transform = transforms.Compose(my_transforms))
|
||||
gos_dataloaders.append(DataLoader(gos_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers_))
|
||||
gos_datasets.append(gos_dataset)
|
||||
|
||||
return gos_dataloaders, gos_datasets
|
||||
|
||||
def im_reader(im_path):
|
||||
return io.imread(im_path)
|
||||
|
||||
def im_preprocess(im,size):
|
||||
if len(im.shape) < 3:
|
||||
im = im[:, :, np.newaxis]
|
||||
if im.shape[2] == 1:
|
||||
im = np.repeat(im, 3, axis=2)
|
||||
im_tensor = torch.tensor(im, dtype=torch.float32)
|
||||
im_tensor = torch.transpose(torch.transpose(im_tensor,1,2),0,1)
|
||||
if(len(size)<2):
|
||||
return im_tensor, im.shape[0:2]
|
||||
else:
|
||||
im_tensor = torch.unsqueeze(im_tensor,0)
|
||||
im_tensor = F.upsample(im_tensor, size, mode="bilinear")
|
||||
im_tensor = torch.squeeze(im_tensor,0)
|
||||
|
||||
return im_tensor.type(torch.uint8), im.shape[0:2]
|
||||
|
||||
def gt_preprocess(gt,size):
|
||||
if len(gt.shape) > 2:
|
||||
gt = gt[:, :, 0]
|
||||
|
||||
gt_tensor = torch.unsqueeze(torch.tensor(gt, dtype=torch.uint8),0)
|
||||
|
||||
if(len(size)<2):
|
||||
return gt_tensor.type(torch.uint8), gt.shape[0:2]
|
||||
else:
|
||||
gt_tensor = torch.unsqueeze(torch.tensor(gt_tensor, dtype=torch.float32),0)
|
||||
gt_tensor = F.upsample(gt_tensor, size, mode="bilinear")
|
||||
gt_tensor = torch.squeeze(gt_tensor,0)
|
||||
|
||||
return gt_tensor.type(torch.uint8), gt.shape[0:2]
|
||||
# return gt_tensor, gt.shape[0:2]
|
||||
|
||||
class GOSRandomHFlip(object):
|
||||
def __init__(self,prob=0.5):
|
||||
self.prob = prob
|
||||
def __call__(self,sample):
|
||||
imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape']
|
||||
|
||||
# random horizontal flip
|
||||
if random.random() >= self.prob:
|
||||
image = torch.flip(image,dims=[2])
|
||||
label = torch.flip(label,dims=[2])
|
||||
|
||||
return {'imidx':imidx,'image':image, 'label':label, 'shape':shape}
|
||||
|
||||
class GOSResize(object):
|
||||
def __init__(self,size=[320,320]):
|
||||
self.size = size
|
||||
def __call__(self,sample):
|
||||
imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape']
|
||||
|
||||
# import time
|
||||
# start = time.time()
|
||||
|
||||
image = torch.squeeze(F.upsample(torch.unsqueeze(image,0),self.size,mode='bilinear'),dim=0)
|
||||
label = torch.squeeze(F.upsample(torch.unsqueeze(label,0),self.size,mode='bilinear'),dim=0)
|
||||
|
||||
# print("time for resize: ", time.time()-start)
|
||||
|
||||
return {'imidx':imidx,'image':image, 'label':label, 'shape':shape}
|
||||
|
||||
class GOSRandomCrop(object):
|
||||
def __init__(self,size=[288,288]):
|
||||
self.size = size
|
||||
def __call__(self,sample):
|
||||
imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape']
|
||||
|
||||
h, w = image.shape[1:]
|
||||
new_h, new_w = self.size
|
||||
|
||||
top = np.random.randint(0, h - new_h)
|
||||
left = np.random.randint(0, w - new_w)
|
||||
|
||||
image = image[:,top:top+new_h,left:left+new_w]
|
||||
label = label[:,top:top+new_h,left:left+new_w]
|
||||
|
||||
return {'imidx':imidx,'image':image, 'label':label, 'shape':shape}
|
||||
|
||||
|
||||
class GOSNormalize(object):
|
||||
def __init__(self, mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]):
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
|
||||
def __call__(self,sample):
|
||||
|
||||
imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape']
|
||||
image = normalize(image,self.mean,self.std)
|
||||
|
||||
return {'imidx':imidx,'image':image, 'label':label, 'shape':shape}
|
||||
|
||||
|
||||
class GOSDatasetCache(Dataset):
|
||||
|
||||
def __init__(self, name_im_gt_list, cache_size=[], cache_path='./cache', cache_file_name='dataset.json', cache_boost=False, transform=None):
|
||||
|
||||
|
||||
self.cache_size = cache_size
|
||||
self.cache_path = cache_path
|
||||
self.cache_file_name = cache_file_name
|
||||
self.cache_boost_name = ""
|
||||
|
||||
self.cache_boost = cache_boost
|
||||
# self.ims_npy = None
|
||||
# self.gts_npy = None
|
||||
|
||||
## cache all the images and ground truth into a single pytorch tensor
|
||||
self.ims_pt = None
|
||||
self.gts_pt = None
|
||||
|
||||
## we will cache the npy as well regardless of the cache_boost
|
||||
# if(self.cache_boost):
|
||||
self.cache_boost_name = cache_file_name.split('.json')[0]
|
||||
|
||||
self.transform = transform
|
||||
|
||||
self.dataset = {}
|
||||
|
||||
## combine different datasets into one
|
||||
dataset_names = []
|
||||
dt_name_list = [] # dataset name per image
|
||||
im_name_list = [] # image name
|
||||
im_path_list = [] # im path
|
||||
gt_path_list = [] # gt path
|
||||
im_ext_list = [] # im ext
|
||||
gt_ext_list = [] # gt ext
|
||||
for i in range(0,len(name_im_gt_list)):
|
||||
dataset_names.append(name_im_gt_list[i]["dataset_name"])
|
||||
# dataset name repeated based on the number of images in this dataset
|
||||
dt_name_list.extend([name_im_gt_list[i]["dataset_name"] for x in name_im_gt_list[i]["im_path"]])
|
||||
im_name_list.extend([x.split(os.sep)[-1].split(name_im_gt_list[i]["im_ext"])[0] for x in name_im_gt_list[i]["im_path"]])
|
||||
im_path_list.extend(name_im_gt_list[i]["im_path"])
|
||||
gt_path_list.extend(name_im_gt_list[i]["gt_path"])
|
||||
im_ext_list.extend([name_im_gt_list[i]["im_ext"] for x in name_im_gt_list[i]["im_path"]])
|
||||
gt_ext_list.extend([name_im_gt_list[i]["gt_ext"] for x in name_im_gt_list[i]["gt_path"]])
|
||||
|
||||
|
||||
self.dataset["data_name"] = dt_name_list
|
||||
self.dataset["im_name"] = im_name_list
|
||||
self.dataset["im_path"] = im_path_list
|
||||
self.dataset["ori_im_path"] = deepcopy(im_path_list)
|
||||
self.dataset["gt_path"] = gt_path_list
|
||||
self.dataset["ori_gt_path"] = deepcopy(gt_path_list)
|
||||
self.dataset["im_shp"] = []
|
||||
self.dataset["gt_shp"] = []
|
||||
self.dataset["im_ext"] = im_ext_list
|
||||
self.dataset["gt_ext"] = gt_ext_list
|
||||
|
||||
|
||||
self.dataset["ims_pt_dir"] = ""
|
||||
self.dataset["gts_pt_dir"] = ""
|
||||
|
||||
self.dataset = self.manage_cache(dataset_names)
|
||||
|
||||
def manage_cache(self,dataset_names):
|
||||
if not os.path.exists(self.cache_path): # create the folder for cache
|
||||
os.mkdir(self.cache_path)
|
||||
cache_folder = os.path.join(self.cache_path, "_".join(dataset_names)+"_"+"x".join([str(x) for x in self.cache_size]))
|
||||
if not os.path.exists(cache_folder): # check if the cache files are there, if not then cache
|
||||
return self.cache(cache_folder)
|
||||
return self.load_cache(cache_folder)
|
||||
|
||||
def cache(self,cache_folder):
|
||||
os.mkdir(cache_folder)
|
||||
cached_dataset = deepcopy(self.dataset)
|
||||
|
||||
# ims_list = []
|
||||
# gts_list = []
|
||||
ims_pt_list = []
|
||||
gts_pt_list = []
|
||||
for i, im_path in tqdm(enumerate(self.dataset["im_path"]), total=len(self.dataset["im_path"])):
|
||||
|
||||
|
||||
im_id = cached_dataset["im_name"][i]
|
||||
|
||||
im = im_reader(im_path)
|
||||
im, im_shp = im_preprocess(im,self.cache_size)
|
||||
im_cache_file = os.path.join(cache_folder,self.dataset["data_name"][i]+"_"+im_id + "_im.pt")
|
||||
torch.save(im,im_cache_file)
|
||||
|
||||
cached_dataset["im_path"][i] = im_cache_file
|
||||
if(self.cache_boost):
|
||||
ims_pt_list.append(torch.unsqueeze(im,0))
|
||||
# ims_list.append(im.cpu().data.numpy().astype(np.uint8))
|
||||
|
||||
|
||||
gt = im_reader(self.dataset["gt_path"][i])
|
||||
gt, gt_shp = gt_preprocess(gt,self.cache_size)
|
||||
gt_cache_file = os.path.join(cache_folder,self.dataset["data_name"][i]+"_"+im_id + "_gt.pt")
|
||||
torch.save(gt,gt_cache_file)
|
||||
cached_dataset["gt_path"][i] = gt_cache_file
|
||||
if(self.cache_boost):
|
||||
gts_pt_list.append(torch.unsqueeze(gt,0))
|
||||
# gts_list.append(gt.cpu().data.numpy().astype(np.uint8))
|
||||
|
||||
# im_shp_cache_file = os.path.join(cache_folder,im_id + "_im_shp.pt")
|
||||
# torch.save(gt_shp, shp_cache_file)
|
||||
cached_dataset["im_shp"].append(im_shp)
|
||||
# self.dataset["im_shp"].append(im_shp)
|
||||
|
||||
# shp_cache_file = os.path.join(cache_folder,im_id + "_gt_shp.pt")
|
||||
# torch.save(gt_shp, shp_cache_file)
|
||||
cached_dataset["gt_shp"].append(gt_shp)
|
||||
# self.dataset["gt_shp"].append(gt_shp)
|
||||
|
||||
if(self.cache_boost):
|
||||
cached_dataset["ims_pt_dir"] = os.path.join(cache_folder, self.cache_boost_name+'_ims.pt')
|
||||
cached_dataset["gts_pt_dir"] = os.path.join(cache_folder, self.cache_boost_name+'_gts.pt')
|
||||
self.ims_pt = torch.cat(ims_pt_list,dim=0)
|
||||
self.gts_pt = torch.cat(gts_pt_list,dim=0)
|
||||
torch.save(torch.cat(ims_pt_list,dim=0),cached_dataset["ims_pt_dir"])
|
||||
torch.save(torch.cat(gts_pt_list,dim=0),cached_dataset["gts_pt_dir"])
|
||||
|
||||
try:
|
||||
json_file = open(os.path.join(cache_folder, self.cache_file_name),"w")
|
||||
json.dump(cached_dataset, json_file)
|
||||
json_file.close()
|
||||
except Exception:
|
||||
raise FileNotFoundError("Cannot create JSON")
|
||||
return cached_dataset
|
||||
|
||||
def load_cache(self, cache_folder):
|
||||
json_file = open(os.path.join(cache_folder,self.cache_file_name),"r")
|
||||
dataset = json.load(json_file)
|
||||
json_file.close()
|
||||
## if cache_boost is true, we will load the image npy and ground truth npy into the RAM
|
||||
## otherwise the pytorch tensor will be loaded
|
||||
if(self.cache_boost):
|
||||
# self.ims_npy = np.load(dataset["ims_npy_dir"])
|
||||
# self.gts_npy = np.load(dataset["gts_npy_dir"])
|
||||
self.ims_pt = torch.load(dataset["ims_pt_dir"], map_location='cpu')
|
||||
self.gts_pt = torch.load(dataset["gts_pt_dir"], map_location='cpu')
|
||||
return dataset
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset["im_path"])
|
||||
|
||||
def __getitem__(self, idx):
|
||||
|
||||
im = None
|
||||
gt = None
|
||||
if(self.cache_boost and self.ims_pt is not None):
|
||||
|
||||
# start = time.time()
|
||||
im = self.ims_pt[idx]#.type(torch.float32)
|
||||
gt = self.gts_pt[idx]#.type(torch.float32)
|
||||
# print(idx, 'time for pt loading: ', time.time()-start)
|
||||
|
||||
else:
|
||||
# import time
|
||||
# start = time.time()
|
||||
# print("tensor***")
|
||||
im_pt_path = os.path.join(self.cache_path,os.sep.join(self.dataset["im_path"][idx].split(os.sep)[-2:]))
|
||||
im = torch.load(im_pt_path)#(self.dataset["im_path"][idx])
|
||||
gt_pt_path = os.path.join(self.cache_path,os.sep.join(self.dataset["gt_path"][idx].split(os.sep)[-2:]))
|
||||
gt = torch.load(gt_pt_path)#(self.dataset["gt_path"][idx])
|
||||
# print(idx,'time for tensor loading: ', time.time()-start)
|
||||
|
||||
|
||||
im_shp = self.dataset["im_shp"][idx]
|
||||
# print("time for loading im and gt: ", time.time()-start)
|
||||
|
||||
# start_time = time.time()
|
||||
im = torch.divide(im,255.0)
|
||||
gt = torch.divide(gt,255.0)
|
||||
# print(idx, 'time for normalize torch divide: ', time.time()-start_time)
|
||||
|
||||
sample = {
|
||||
"imidx": torch.from_numpy(np.array(idx)),
|
||||
"image": im,
|
||||
"label": gt,
|
||||
"shape": torch.from_numpy(np.array(im_shp)),
|
||||
}
|
||||
|
||||
if self.transform:
|
||||
sample = self.transform(sample)
|
||||
|
||||
return sample
|
188
IS-Net/hce_metric_main.py
Normal file
@ -0,0 +1,188 @@
|
||||
## hce_metric.py
|
||||
import numpy as np
|
||||
from skimage import io
|
||||
import matplotlib.pyplot as plt
|
||||
import cv2 as cv
|
||||
from skimage.morphology import skeletonize
|
||||
from skimage.morphology import erosion, dilation, disk
|
||||
from skimage.measure import label
|
||||
|
||||
import os
|
||||
import sys
|
||||
from tqdm import tqdm
|
||||
from glob import glob
|
||||
import pickle as pkl
|
||||
|
||||
def filter_bdy_cond(bdy_, mask, cond):
|
||||
|
||||
cond = cv.dilate(cond.astype(np.uint8),disk(1))
|
||||
labels = label(mask) # find the connected regions
|
||||
lbls = np.unique(labels) # the indices of the connected regions
|
||||
indep = np.ones(lbls.shape[0]) # the label of each connected regions
|
||||
indep[0] = 0 # 0 indicate the background region
|
||||
|
||||
boundaries = []
|
||||
h,w = cond.shape[0:2]
|
||||
ind_map = np.zeros((h,w))
|
||||
indep_cnt = 0
|
||||
|
||||
for i in range(0,len(bdy_)):
|
||||
tmp_bdies = []
|
||||
tmp_bdy = []
|
||||
for j in range(0,bdy_[i].shape[0]):
|
||||
r, c = bdy_[i][j,0,1],bdy_[i][j,0,0]
|
||||
|
||||
if(np.sum(cond[r,c])==0 or ind_map[r,c]!=0):
|
||||
if(len(tmp_bdy)>0):
|
||||
tmp_bdies.append(tmp_bdy)
|
||||
tmp_bdy = []
|
||||
continue
|
||||
tmp_bdy.append([c,r])
|
||||
ind_map[r,c] = ind_map[r,c] + 1
|
||||
indep[labels[r,c]] = 0 # indicates part of the boundary of this region needs human correction
|
||||
if(len(tmp_bdy)>0):
|
||||
tmp_bdies.append(tmp_bdy)
|
||||
|
||||
# check if the first and the last boundaries are connected
|
||||
# if yes, invert the first boundary and attach it after the last boundary
|
||||
if(len(tmp_bdies)>1):
|
||||
first_x, first_y = tmp_bdies[0][0]
|
||||
last_x, last_y = tmp_bdies[-1][-1]
|
||||
if((abs(first_x-last_x)==1 and first_y==last_y) or
|
||||
(first_x==last_x and abs(first_y-last_y)==1) or
|
||||
(abs(first_x-last_x)==1 and abs(first_y-last_y)==1)
|
||||
):
|
||||
tmp_bdies[-1].extend(tmp_bdies[0][::-1])
|
||||
del tmp_bdies[0]
|
||||
|
||||
for k in range(0,len(tmp_bdies)):
|
||||
tmp_bdies[k] = np.array(tmp_bdies[k])[:,np.newaxis,:]
|
||||
if(len(tmp_bdies)>0):
|
||||
boundaries.extend(tmp_bdies)
|
||||
|
||||
return boundaries, np.sum(indep)
|
||||
|
||||
# this function approximate each boundary by DP algorithm
|
||||
# https://en.wikipedia.org/wiki/Ramer%E2%80%93Douglas%E2%80%93Peucker_algorithm
|
||||
def approximate_RDP(boundaries,epsilon=1.0):
|
||||
|
||||
boundaries_ = []
|
||||
boundaries_len_ = []
|
||||
pixel_cnt_ = 0
|
||||
|
||||
# polygon approximate of each boundary
|
||||
for i in range(0,len(boundaries)):
|
||||
boundaries_.append(cv.approxPolyDP(boundaries[i],epsilon,False))
|
||||
|
||||
# count the control points number of each boundary and the total control points number of all the boundaries
|
||||
for i in range(0,len(boundaries_)):
|
||||
boundaries_len_.append(len(boundaries_[i]))
|
||||
pixel_cnt_ = pixel_cnt_ + len(boundaries_[i])
|
||||
|
||||
return boundaries_, boundaries_len_, pixel_cnt_
|
||||
|
||||
|
||||
def relax_HCE(gt, rs, gt_ske, relax=5, epsilon=2.0):
|
||||
# print("max(gt_ske): ", np.amax(gt_ske))
|
||||
# gt_ske = gt_ske>128
|
||||
# print("max(gt_ske): ", np.amax(gt_ske))
|
||||
|
||||
# Binarize gt
|
||||
if(len(gt.shape)>2):
|
||||
gt = gt[:,:,0]
|
||||
|
||||
epsilon_gt = 128#(np.amin(gt)+np.amax(gt))/2.0
|
||||
gt = (gt>epsilon_gt).astype(np.uint8)
|
||||
|
||||
# Binarize rs
|
||||
if(len(rs.shape)>2):
|
||||
rs = rs[:,:,0]
|
||||
epsilon_rs = 128#(np.amin(rs)+np.amax(rs))/2.0
|
||||
rs = (rs>epsilon_rs).astype(np.uint8)
|
||||
|
||||
Union = np.logical_or(gt,rs)
|
||||
TP = np.logical_and(gt,rs)
|
||||
FP = rs - TP
|
||||
FN = gt - TP
|
||||
|
||||
# relax the Union of gt and rs
|
||||
Union_erode = Union.copy()
|
||||
Union_erode = cv.erode(Union_erode.astype(np.uint8),disk(1),iterations=relax)
|
||||
|
||||
# --- get the relaxed False Positive regions for computing the human efforts in correcting them ---
|
||||
FP_ = np.logical_and(FP,Union_erode) # get the relaxed FP
|
||||
for i in range(0,relax):
|
||||
FP_ = cv.dilate(FP_.astype(np.uint8),disk(1))
|
||||
FP_ = np.logical_and(FP_, 1-np.logical_or(TP,FN))
|
||||
FP_ = np.logical_and(FP, FP_)
|
||||
|
||||
# --- get the relaxed False Negative regions for computing the human efforts in correcting them ---
|
||||
FN_ = np.logical_and(FN,Union_erode) # preserve the structural components of FN
|
||||
## recover the FN, where pixels are not close to the TP borders
|
||||
for i in range(0,relax):
|
||||
FN_ = cv.dilate(FN_.astype(np.uint8),disk(1))
|
||||
FN_ = np.logical_and(FN_,1-np.logical_or(TP,FP))
|
||||
FN_ = np.logical_and(FN,FN_)
|
||||
FN_ = np.logical_or(FN_, np.logical_xor(gt_ske,np.logical_and(TP,gt_ske))) # preserve the structural components of FN
|
||||
|
||||
## 2. =============Find exact polygon control points and independent regions==============
|
||||
## find contours from FP_
|
||||
ctrs_FP, hier_FP = cv.findContours(FP_.astype(np.uint8), cv.RETR_TREE, cv.CHAIN_APPROX_NONE)
|
||||
## find control points and independent regions for human correction
|
||||
bdies_FP, indep_cnt_FP = filter_bdy_cond(ctrs_FP, FP_, np.logical_or(TP,FN_))
|
||||
## find contours from FN_
|
||||
ctrs_FN, hier_FN = cv.findContours(FN_.astype(np.uint8), cv.RETR_TREE, cv.CHAIN_APPROX_NONE)
|
||||
## find control points and independent regions for human correction
|
||||
bdies_FN, indep_cnt_FN = filter_bdy_cond(ctrs_FN, FN_, 1-np.logical_or(np.logical_or(TP,FP_),FN_))
|
||||
|
||||
poly_FP, poly_FP_len, poly_FP_point_cnt = approximate_RDP(bdies_FP,epsilon=epsilon)
|
||||
poly_FN, poly_FN_len, poly_FN_point_cnt = approximate_RDP(bdies_FN,epsilon=epsilon)
|
||||
|
||||
return poly_FP_point_cnt, indep_cnt_FP, poly_FN_point_cnt, indep_cnt_FN
|
||||
|
||||
def compute_hce(pred_root,gt_root,gt_ske_root):
|
||||
|
||||
gt_name_list = glob(pred_root+'/*.png')
|
||||
gt_name_list = sorted([x.split('/')[-1] for x in gt_name_list])
|
||||
|
||||
hces = []
|
||||
for gt_name in tqdm(gt_name_list, total=len(gt_name_list)):
|
||||
gt_path = os.path.join(gt_root, gt_name)
|
||||
pred_path = os.path.join(pred_root, gt_name)
|
||||
|
||||
gt = cv.imread(gt_path, cv.IMREAD_GRAYSCALE)
|
||||
pred = cv.imread(pred_path, cv.IMREAD_GRAYSCALE)
|
||||
|
||||
ske_path = os.path.join(gt_ske_root,gt_name)
|
||||
if os.path.exists(ske_path):
|
||||
ske = cv.imread(ske_path,cv.IMREAD_GRAYSCALE)
|
||||
ske = ske>128
|
||||
else:
|
||||
ske = skeletonize(gt>128)
|
||||
|
||||
FP_points, FP_indep, FN_points, FN_indep = relax_HCE(gt, pred,ske)
|
||||
print(gt_path.split('/')[-1],FP_points, FP_indep, FN_points, FN_indep)
|
||||
hces.append([FP_points, FP_indep, FN_points, FN_indep, FP_points+FP_indep+FN_points+FN_indep])
|
||||
|
||||
hce_metric ={'names': gt_name_list,
|
||||
'hces': hces}
|
||||
|
||||
|
||||
file_metric = open(pred_root+'/hce_metric.pkl','wb')
|
||||
pkl.dump(hce_metric,file_metric)
|
||||
# file_metrics.write(cmn_metrics)
|
||||
file_metric.close()
|
||||
|
||||
return np.mean(np.array(hces)[:,-1])
|
||||
|
||||
def main():
|
||||
|
||||
gt_root = "../DIS5K/DIS-VD/gt"
|
||||
gt_ske_root = ""
|
||||
pred_root = "../Results/isnet(ours)/DIS-VD"
|
||||
|
||||
print("The average HCE metric: ", compute_hce(pred_root,gt_root,gt_ske_root))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
1
IS-Net/models/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from models.isnet import ISNetGTEncoder, ISNetDIS
|
BIN
IS-Net/models/__pycache__/__init__.cpython-37.pyc
Normal file
BIN
IS-Net/models/__pycache__/u2netfast.cpython-37.pyc
Normal file
614
IS-Net/models/isnet.py
Normal file
@ -0,0 +1,614 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torchvision import models
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
bce_loss = nn.BCELoss(size_average=True)
|
||||
def muti_loss_fusion(preds, target):
|
||||
loss0 = 0.0
|
||||
loss = 0.0
|
||||
|
||||
for i in range(0,len(preds)):
|
||||
# print("i: ", i, preds[i].shape)
|
||||
if(preds[i].shape[2]!=target.shape[2] or preds[i].shape[3]!=target.shape[3]):
|
||||
# tmp_target = _upsample_like(target,preds[i])
|
||||
tmp_target = F.interpolate(target, size=preds[i].size()[2:], mode='bilinear', align_corners=True)
|
||||
loss = loss + bce_loss(preds[i],tmp_target)
|
||||
else:
|
||||
loss = loss + bce_loss(preds[i],target)
|
||||
if(i==0):
|
||||
loss0 = loss
|
||||
return loss0, loss
|
||||
|
||||
fea_loss = nn.MSELoss(size_average=True)
|
||||
kl_loss = nn.KLDivLoss(size_average=True)
|
||||
l1_loss = nn.L1Loss(size_average=True)
|
||||
smooth_l1_loss = nn.SmoothL1Loss(size_average=True)
|
||||
def muti_loss_fusion_kl(preds, target, dfs, fs, mode='MSE'):
|
||||
loss0 = 0.0
|
||||
loss = 0.0
|
||||
|
||||
for i in range(0,len(preds)):
|
||||
# print("i: ", i, preds[i].shape)
|
||||
if(preds[i].shape[2]!=target.shape[2] or preds[i].shape[3]!=target.shape[3]):
|
||||
# tmp_target = _upsample_like(target,preds[i])
|
||||
tmp_target = F.interpolate(target, size=preds[i].size()[2:], mode='bilinear', align_corners=True)
|
||||
loss = loss + bce_loss(preds[i],tmp_target)
|
||||
else:
|
||||
loss = loss + bce_loss(preds[i],target)
|
||||
if(i==0):
|
||||
loss0 = loss
|
||||
|
||||
for i in range(0,len(dfs)):
|
||||
if(mode=='MSE'):
|
||||
loss = loss + fea_loss(dfs[i],fs[i]) ### add the mse loss of features as additional constraints
|
||||
# print("fea_loss: ", fea_loss(dfs[i],fs[i]).item())
|
||||
elif(mode=='KL'):
|
||||
loss = loss + kl_loss(F.log_softmax(dfs[i],dim=1),F.softmax(fs[i],dim=1))
|
||||
# print("kl_loss: ", kl_loss(F.log_softmax(dfs[i],dim=1),F.softmax(fs[i],dim=1)).item())
|
||||
elif(mode=='MAE'):
|
||||
loss = loss + l1_loss(dfs[i],fs[i])
|
||||
# print("ls_loss: ", l1_loss(dfs[i],fs[i]))
|
||||
elif(mode=='SmoothL1'):
|
||||
loss = loss + smooth_l1_loss(dfs[i],fs[i])
|
||||
# print("SmoothL1: ", smooth_l1_loss(dfs[i],fs[i]).item())
|
||||
|
||||
return loss0, loss
|
||||
|
||||
class REBNCONV(nn.Module):
|
||||
def __init__(self,in_ch=3,out_ch=3,dirate=1,stride=1):
|
||||
super(REBNCONV,self).__init__()
|
||||
|
||||
self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate,stride=stride)
|
||||
self.bn_s1 = nn.BatchNorm2d(out_ch)
|
||||
self.relu_s1 = nn.ReLU(inplace=True)
|
||||
|
||||
def forward(self,x):
|
||||
|
||||
hx = x
|
||||
xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
|
||||
|
||||
return xout
|
||||
|
||||
## upsample tensor 'src' to have the same spatial size with tensor 'tar'
|
||||
def _upsample_like(src,tar):
|
||||
|
||||
src = F.upsample(src,size=tar.shape[2:],mode='bilinear')
|
||||
|
||||
return src
|
||||
|
||||
|
||||
### RSU-7 ###
|
||||
class RSU7(nn.Module):
|
||||
|
||||
def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
|
||||
super(RSU7,self).__init__()
|
||||
|
||||
self.in_ch = in_ch
|
||||
self.mid_ch = mid_ch
|
||||
self.out_ch = out_ch
|
||||
|
||||
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) ## 1 -> 1/2
|
||||
|
||||
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
|
||||
self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||
|
||||
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
||||
self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||
|
||||
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
||||
self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||
|
||||
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
||||
self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||
|
||||
self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
||||
self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||
|
||||
self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
||||
|
||||
self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)
|
||||
|
||||
self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
||||
self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
||||
self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
||||
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
||||
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
||||
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
|
||||
|
||||
def forward(self,x):
|
||||
b, c, h, w = x.shape
|
||||
|
||||
hx = x
|
||||
hxin = self.rebnconvin(hx)
|
||||
|
||||
hx1 = self.rebnconv1(hxin)
|
||||
hx = self.pool1(hx1)
|
||||
|
||||
hx2 = self.rebnconv2(hx)
|
||||
hx = self.pool2(hx2)
|
||||
|
||||
hx3 = self.rebnconv3(hx)
|
||||
hx = self.pool3(hx3)
|
||||
|
||||
hx4 = self.rebnconv4(hx)
|
||||
hx = self.pool4(hx4)
|
||||
|
||||
hx5 = self.rebnconv5(hx)
|
||||
hx = self.pool5(hx5)
|
||||
|
||||
hx6 = self.rebnconv6(hx)
|
||||
|
||||
hx7 = self.rebnconv7(hx6)
|
||||
|
||||
hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))
|
||||
hx6dup = _upsample_like(hx6d,hx5)
|
||||
|
||||
hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1))
|
||||
hx5dup = _upsample_like(hx5d,hx4)
|
||||
|
||||
hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
|
||||
hx4dup = _upsample_like(hx4d,hx3)
|
||||
|
||||
hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
|
||||
hx3dup = _upsample_like(hx3d,hx2)
|
||||
|
||||
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
|
||||
hx2dup = _upsample_like(hx2d,hx1)
|
||||
|
||||
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
|
||||
|
||||
return hx1d + hxin
|
||||
|
||||
|
||||
### RSU-6 ###
|
||||
class RSU6(nn.Module):
|
||||
|
||||
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
||||
super(RSU6,self).__init__()
|
||||
|
||||
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
|
||||
|
||||
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
|
||||
self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||
|
||||
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
||||
self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||
|
||||
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
||||
self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||
|
||||
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
||||
self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||
|
||||
self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
||||
|
||||
self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)
|
||||
|
||||
self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
||||
self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
||||
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
||||
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
||||
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
|
||||
|
||||
def forward(self,x):
|
||||
|
||||
hx = x
|
||||
|
||||
hxin = self.rebnconvin(hx)
|
||||
|
||||
hx1 = self.rebnconv1(hxin)
|
||||
hx = self.pool1(hx1)
|
||||
|
||||
hx2 = self.rebnconv2(hx)
|
||||
hx = self.pool2(hx2)
|
||||
|
||||
hx3 = self.rebnconv3(hx)
|
||||
hx = self.pool3(hx3)
|
||||
|
||||
hx4 = self.rebnconv4(hx)
|
||||
hx = self.pool4(hx4)
|
||||
|
||||
hx5 = self.rebnconv5(hx)
|
||||
|
||||
hx6 = self.rebnconv6(hx5)
|
||||
|
||||
|
||||
hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1))
|
||||
hx5dup = _upsample_like(hx5d,hx4)
|
||||
|
||||
hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
|
||||
hx4dup = _upsample_like(hx4d,hx3)
|
||||
|
||||
hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
|
||||
hx3dup = _upsample_like(hx3d,hx2)
|
||||
|
||||
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
|
||||
hx2dup = _upsample_like(hx2d,hx1)
|
||||
|
||||
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
|
||||
|
||||
return hx1d + hxin
|
||||
|
||||
### RSU-5 ###
|
||||
class RSU5(nn.Module):
|
||||
|
||||
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
||||
super(RSU5,self).__init__()
|
||||
|
||||
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
|
||||
|
||||
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
|
||||
self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||
|
||||
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
||||
self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||
|
||||
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
||||
self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||
|
||||
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
||||
|
||||
self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)
|
||||
|
||||
self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
||||
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
||||
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
||||
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
|
||||
|
||||
def forward(self,x):
|
||||
|
||||
hx = x
|
||||
|
||||
hxin = self.rebnconvin(hx)
|
||||
|
||||
hx1 = self.rebnconv1(hxin)
|
||||
hx = self.pool1(hx1)
|
||||
|
||||
hx2 = self.rebnconv2(hx)
|
||||
hx = self.pool2(hx2)
|
||||
|
||||
hx3 = self.rebnconv3(hx)
|
||||
hx = self.pool3(hx3)
|
||||
|
||||
hx4 = self.rebnconv4(hx)
|
||||
|
||||
hx5 = self.rebnconv5(hx4)
|
||||
|
||||
hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))
|
||||
hx4dup = _upsample_like(hx4d,hx3)
|
||||
|
||||
hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
|
||||
hx3dup = _upsample_like(hx3d,hx2)
|
||||
|
||||
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
|
||||
hx2dup = _upsample_like(hx2d,hx1)
|
||||
|
||||
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
|
||||
|
||||
return hx1d + hxin
|
||||
|
||||
### RSU-4 ###
|
||||
class RSU4(nn.Module):
|
||||
|
||||
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
||||
super(RSU4,self).__init__()
|
||||
|
||||
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
|
||||
|
||||
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
|
||||
self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||
|
||||
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
||||
self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||
|
||||
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
||||
|
||||
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)
|
||||
|
||||
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
||||
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
||||
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
|
||||
|
||||
def forward(self,x):
|
||||
|
||||
hx = x
|
||||
|
||||
hxin = self.rebnconvin(hx)
|
||||
|
||||
hx1 = self.rebnconv1(hxin)
|
||||
hx = self.pool1(hx1)
|
||||
|
||||
hx2 = self.rebnconv2(hx)
|
||||
hx = self.pool2(hx2)
|
||||
|
||||
hx3 = self.rebnconv3(hx)
|
||||
|
||||
hx4 = self.rebnconv4(hx3)
|
||||
|
||||
hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
|
||||
hx3dup = _upsample_like(hx3d,hx2)
|
||||
|
||||
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
|
||||
hx2dup = _upsample_like(hx2d,hx1)
|
||||
|
||||
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
|
||||
|
||||
return hx1d + hxin
|
||||
|
||||
### RSU-4F ###
|
||||
class RSU4F(nn.Module):
|
||||
|
||||
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
||||
super(RSU4F,self).__init__()
|
||||
|
||||
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
|
||||
|
||||
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
|
||||
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
|
||||
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)
|
||||
|
||||
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)
|
||||
|
||||
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
|
||||
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
|
||||
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
|
||||
|
||||
def forward(self,x):
|
||||
|
||||
hx = x
|
||||
|
||||
hxin = self.rebnconvin(hx)
|
||||
|
||||
hx1 = self.rebnconv1(hxin)
|
||||
hx2 = self.rebnconv2(hx1)
|
||||
hx3 = self.rebnconv3(hx2)
|
||||
|
||||
hx4 = self.rebnconv4(hx3)
|
||||
|
||||
hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
|
||||
hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
|
||||
hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))
|
||||
|
||||
return hx1d + hxin
|
||||
|
||||
|
||||
class myrebnconv(nn.Module):
|
||||
def __init__(self, in_ch=3,
|
||||
out_ch=1,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
dilation=1,
|
||||
groups=1):
|
||||
super(myrebnconv,self).__init__()
|
||||
|
||||
self.conv = nn.Conv2d(in_ch,
|
||||
out_ch,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups)
|
||||
self.bn = nn.BatchNorm2d(out_ch)
|
||||
self.rl = nn.ReLU(inplace=True)
|
||||
|
||||
def forward(self,x):
|
||||
return self.rl(self.bn(self.conv(x)))
|
||||
|
||||
|
||||
class ISNetGTEncoder(nn.Module):
|
||||
|
||||
def __init__(self,in_ch=1,out_ch=1):
|
||||
super(ISNetGTEncoder,self).__init__()
|
||||
|
||||
self.conv_in = myrebnconv(in_ch,16,3,stride=2,padding=1) # nn.Conv2d(in_ch,64,3,stride=2,padding=1)
|
||||
|
||||
self.stage1 = RSU7(16,16,64)
|
||||
self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||
|
||||
self.stage2 = RSU6(64,16,64)
|
||||
self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||
|
||||
self.stage3 = RSU5(64,32,128)
|
||||
self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||
|
||||
self.stage4 = RSU4(128,32,256)
|
||||
self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||
|
||||
self.stage5 = RSU4F(256,64,512)
|
||||
self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||
|
||||
self.stage6 = RSU4F(512,64,512)
|
||||
|
||||
|
||||
self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
|
||||
self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
|
||||
self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
|
||||
self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
|
||||
self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
|
||||
self.side6 = nn.Conv2d(512,out_ch,3,padding=1)
|
||||
|
||||
def compute_loss_max(self, preds, targets, fs):
|
||||
|
||||
return muti_loss_fusion_max(preds, targets,fs)
|
||||
|
||||
def compute_loss(self, preds, targets):
|
||||
|
||||
return muti_loss_fusion(preds,targets)
|
||||
|
||||
def forward(self,x):
|
||||
|
||||
hx = x
|
||||
|
||||
hxin = self.conv_in(hx)
|
||||
# hx = self.pool_in(hxin)
|
||||
|
||||
#stage 1
|
||||
hx1 = self.stage1(hxin)
|
||||
hx = self.pool12(hx1)
|
||||
|
||||
#stage 2
|
||||
hx2 = self.stage2(hx)
|
||||
hx = self.pool23(hx2)
|
||||
|
||||
#stage 3
|
||||
hx3 = self.stage3(hx)
|
||||
hx = self.pool34(hx3)
|
||||
|
||||
#stage 4
|
||||
hx4 = self.stage4(hx)
|
||||
hx = self.pool45(hx4)
|
||||
|
||||
#stage 5
|
||||
hx5 = self.stage5(hx)
|
||||
hx = self.pool56(hx5)
|
||||
|
||||
#stage 6
|
||||
hx6 = self.stage6(hx)
|
||||
|
||||
|
||||
#side output
|
||||
d1 = self.side1(hx1)
|
||||
d1 = _upsample_like(d1,x)
|
||||
|
||||
d2 = self.side2(hx2)
|
||||
d2 = _upsample_like(d2,x)
|
||||
|
||||
d3 = self.side3(hx3)
|
||||
d3 = _upsample_like(d3,x)
|
||||
|
||||
d4 = self.side4(hx4)
|
||||
d4 = _upsample_like(d4,x)
|
||||
|
||||
d5 = self.side5(hx5)
|
||||
d5 = _upsample_like(d5,x)
|
||||
|
||||
d6 = self.side6(hx6)
|
||||
d6 = _upsample_like(d6,x)
|
||||
|
||||
# d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
|
||||
|
||||
return [F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)], [hx1,hx2,hx3,hx4,hx5,hx6]
|
||||
|
||||
class ISNetDIS(nn.Module):
|
||||
|
||||
def __init__(self,in_ch=3,out_ch=1):
|
||||
super(ISNetDIS,self).__init__()
|
||||
|
||||
self.conv_in = nn.Conv2d(in_ch,64,3,stride=2,padding=1)
|
||||
self.pool_in = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||
|
||||
self.stage1 = RSU7(64,32,64)
|
||||
self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||
|
||||
self.stage2 = RSU6(64,32,128)
|
||||
self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||
|
||||
self.stage3 = RSU5(128,64,256)
|
||||
self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||
|
||||
self.stage4 = RSU4(256,128,512)
|
||||
self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||
|
||||
self.stage5 = RSU4F(512,256,512)
|
||||
self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
||||
|
||||
self.stage6 = RSU4F(512,256,512)
|
||||
|
||||
# decoder
|
||||
self.stage5d = RSU4F(1024,256,512)
|
||||
self.stage4d = RSU4(1024,128,256)
|
||||
self.stage3d = RSU5(512,64,128)
|
||||
self.stage2d = RSU6(256,32,64)
|
||||
self.stage1d = RSU7(128,16,64)
|
||||
|
||||
self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
|
||||
self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
|
||||
self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
|
||||
self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
|
||||
self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
|
||||
self.side6 = nn.Conv2d(512,out_ch,3,padding=1)
|
||||
|
||||
# self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
|
||||
|
||||
def compute_loss_kl(self, preds, targets, dfs, fs, mode='MSE'):
|
||||
|
||||
# return muti_loss_fusion(preds,targets)
|
||||
return muti_loss_fusion_kl(preds, targets, dfs, fs, mode=mode)
|
||||
|
||||
def compute_loss(self, preds, targets):
|
||||
|
||||
# return muti_loss_fusion(preds,targets)
|
||||
return muti_loss_fusion(preds, targets)
|
||||
|
||||
def forward(self,x):
|
||||
|
||||
hx = x
|
||||
|
||||
hxin = self.conv_in(hx)
|
||||
hx = self.pool_in(hxin)
|
||||
|
||||
#stage 1
|
||||
hx1 = self.stage1(hxin)
|
||||
hx = self.pool12(hx1)
|
||||
|
||||
#stage 2
|
||||
hx2 = self.stage2(hx)
|
||||
hx = self.pool23(hx2)
|
||||
|
||||
#stage 3
|
||||
hx3 = self.stage3(hx)
|
||||
hx = self.pool34(hx3)
|
||||
|
||||
#stage 4
|
||||
hx4 = self.stage4(hx)
|
||||
hx = self.pool45(hx4)
|
||||
|
||||
#stage 5
|
||||
hx5 = self.stage5(hx)
|
||||
hx = self.pool56(hx5)
|
||||
|
||||
#stage 6
|
||||
hx6 = self.stage6(hx)
|
||||
hx6up = _upsample_like(hx6,hx5)
|
||||
|
||||
#-------------------- decoder --------------------
|
||||
hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
|
||||
hx5dup = _upsample_like(hx5d,hx4)
|
||||
|
||||
hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
|
||||
hx4dup = _upsample_like(hx4d,hx3)
|
||||
|
||||
hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
|
||||
hx3dup = _upsample_like(hx3d,hx2)
|
||||
|
||||
hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
|
||||
hx2dup = _upsample_like(hx2d,hx1)
|
||||
|
||||
hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
|
||||
|
||||
|
||||
#side output
|
||||
d1 = self.side1(hx1d)
|
||||
d1 = _upsample_like(d1,x)
|
||||
|
||||
d2 = self.side2(hx2d)
|
||||
d2 = _upsample_like(d2,x)
|
||||
|
||||
d3 = self.side3(hx3d)
|
||||
d3 = _upsample_like(d3,x)
|
||||
|
||||
d4 = self.side4(hx4d)
|
||||
d4 = _upsample_like(d4,x)
|
||||
|
||||
d5 = self.side5(hx5d)
|
||||
d5 = _upsample_like(d5,x)
|
||||
|
||||
d6 = self.side6(hx6)
|
||||
d6 = _upsample_like(d6,x)
|
||||
|
||||
# d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
|
||||
|
||||
return [F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)],[hx1d,hx2d,hx3d,hx4d,hx5d,hx6]
|
92
IS-Net/pytorch18.yml
Normal file
@ -0,0 +1,92 @@
|
||||
name: pytorch18
|
||||
channels:
|
||||
- conda-forge
|
||||
- anaconda
|
||||
- pytorch
|
||||
- defaults
|
||||
dependencies:
|
||||
- _libgcc_mutex=0.1=main
|
||||
- _openmp_mutex=4.5=1_gnu
|
||||
- blas=1.0=mkl
|
||||
- brotli=1.0.9=he6710b0_2
|
||||
- bzip2=1.0.8=h7b6447c_0
|
||||
- ca-certificates=2022.2.1=h06a4308_0
|
||||
- certifi=2021.10.8=py37h06a4308_2
|
||||
- cloudpickle=2.0.0=pyhd3eb1b0_0
|
||||
- colorama=0.4.4=pyhd3eb1b0_0
|
||||
- cudatoolkit=10.2.89=hfd86e86_1
|
||||
- cycler=0.11.0=pyhd3eb1b0_0
|
||||
- cytoolz=0.11.0=py37h7b6447c_0
|
||||
- dask-core=2021.10.0=pyhd3eb1b0_0
|
||||
- ffmpeg=4.3=hf484d3e_0
|
||||
- fonttools=4.25.0=pyhd3eb1b0_0
|
||||
- freetype=2.11.0=h70c0345_0
|
||||
- fsspec=2022.2.0=pyhd3eb1b0_0
|
||||
- gmp=6.2.1=h2531618_2
|
||||
- gnutls=3.6.15=he1e5248_0
|
||||
- imageio=2.9.0=pyhd3eb1b0_0
|
||||
- intel-openmp=2021.4.0=h06a4308_3561
|
||||
- jpeg=9b=h024ee3a_2
|
||||
- kiwisolver=1.3.2=py37h295c915_0
|
||||
- lame=3.100=h7b6447c_0
|
||||
- lcms2=2.12=h3be6417_0
|
||||
- ld_impl_linux-64=2.35.1=h7274673_9
|
||||
- libffi=3.3=he6710b0_2
|
||||
- libgcc-ng=9.3.0=h5101ec6_17
|
||||
- libgfortran-ng=7.5.0=ha8ba4b0_17
|
||||
- libgfortran4=7.5.0=ha8ba4b0_17
|
||||
- libgomp=9.3.0=h5101ec6_17
|
||||
- libiconv=1.15=h63c8f33_5
|
||||
- libidn2=2.3.2=h7f8727e_0
|
||||
- libpng=1.6.37=hbc83047_0
|
||||
- libstdcxx-ng=9.3.0=hd4cf53a_17
|
||||
- libtasn1=4.16.0=h27cfd23_0
|
||||
- libtiff=4.2.0=h85742a9_0
|
||||
- libunistring=0.9.10=h27cfd23_0
|
||||
- libuv=1.40.0=h7b6447c_0
|
||||
- libwebp-base=1.2.2=h7f8727e_0
|
||||
- locket=0.2.1=py37h06a4308_2
|
||||
- lz4-c=1.9.3=h295c915_1
|
||||
- matplotlib-base=3.5.1=py37ha18d171_1
|
||||
- mkl=2021.4.0=h06a4308_640
|
||||
- mkl-service=2.4.0=py37h7f8727e_0
|
||||
- mkl_fft=1.3.1=py37hd3c417c_0
|
||||
- mkl_random=1.2.2=py37h51133e4_0
|
||||
- munkres=1.1.4=py_0
|
||||
- ncurses=6.3=h7f8727e_2
|
||||
- nettle=3.7.3=hbbd107a_1
|
||||
- networkx=2.6.3=pyhd3eb1b0_0
|
||||
- ninja=1.10.2=py37hd09550d_3
|
||||
- numpy=1.21.2=py37h20f2e39_0
|
||||
- numpy-base=1.21.2=py37h79a1101_0
|
||||
- olefile=0.46=py37_0
|
||||
- openh264=2.1.1=h4ff587b_0
|
||||
- openssl=1.1.1n=h7f8727e_0
|
||||
- packaging=21.3=pyhd3eb1b0_0
|
||||
- partd=1.2.0=pyhd3eb1b0_1
|
||||
- pillow=8.0.0=py37h9a89aac_0
|
||||
- pip=21.2.2=py37h06a4308_0
|
||||
- pyparsing=3.0.4=pyhd3eb1b0_0
|
||||
- python=3.7.11=h12debd9_0
|
||||
- python-dateutil=2.8.2=pyhd3eb1b0_0
|
||||
- pytorch=1.8.0=py3.7_cuda10.2_cudnn7.6.5_0
|
||||
- pywavelets=1.1.1=py37h7b6447c_2
|
||||
- pyyaml=6.0=py37h7f8727e_1
|
||||
- readline=8.1.2=h7f8727e_1
|
||||
- scikit-image=0.15.0=py37hb3f55d8_2
|
||||
- scipy=1.7.3=py37hc147768_0
|
||||
- setuptools=58.0.4=py37h06a4308_0
|
||||
- six=1.16.0=pyhd3eb1b0_1
|
||||
- sqlite=3.38.0=hc218d9a_0
|
||||
- tk=8.6.11=h1ccaba5_0
|
||||
- toolz=0.11.2=pyhd3eb1b0_0
|
||||
- torchaudio=0.8.0=py37
|
||||
- torchvision=0.9.0=py37_cu102
|
||||
- tqdm=4.63.0=pyhd8ed1ab_0
|
||||
- typing_extensions=3.10.0.2=pyh06a4308_0
|
||||
- wheel=0.37.1=pyhd3eb1b0_0
|
||||
- xz=5.2.5=h7b6447c_0
|
||||
- yaml=0.2.5=h7b6447c_0
|
||||
- zlib=1.2.11=h7f8727e_4
|
||||
- zstd=1.4.9=haebb681_0
|
||||
prefix: /home/solar/anaconda3/envs/pytorch18
|
87
IS-Net/requirements.txt
Normal file
@ -0,0 +1,87 @@
|
||||
# This file may be used to create an environment using:
|
||||
# $ conda create --name <env> --file <this file>
|
||||
# platform: linux-64
|
||||
_libgcc_mutex=0.1=main
|
||||
_openmp_mutex=4.5=1_gnu
|
||||
blas=1.0=mkl
|
||||
brotli=1.0.9=he6710b0_2
|
||||
bzip2=1.0.8=h7b6447c_0
|
||||
ca-certificates=2022.2.1=h06a4308_0
|
||||
certifi=2021.10.8=py37h06a4308_2
|
||||
cloudpickle=2.0.0=pyhd3eb1b0_0
|
||||
colorama=0.4.4=pyhd3eb1b0_0
|
||||
cudatoolkit=10.2.89=hfd86e86_1
|
||||
cycler=0.11.0=pyhd3eb1b0_0
|
||||
cytoolz=0.11.0=py37h7b6447c_0
|
||||
dask-core=2021.10.0=pyhd3eb1b0_0
|
||||
ffmpeg=4.3=hf484d3e_0
|
||||
fonttools=4.25.0=pyhd3eb1b0_0
|
||||
freetype=2.11.0=h70c0345_0
|
||||
fsspec=2022.2.0=pyhd3eb1b0_0
|
||||
gmp=6.2.1=h2531618_2
|
||||
gnutls=3.6.15=he1e5248_0
|
||||
imageio=2.9.0=pyhd3eb1b0_0
|
||||
intel-openmp=2021.4.0=h06a4308_3561
|
||||
jpeg=9b=h024ee3a_2
|
||||
kiwisolver=1.3.2=py37h295c915_0
|
||||
lame=3.100=h7b6447c_0
|
||||
lcms2=2.12=h3be6417_0
|
||||
ld_impl_linux-64=2.35.1=h7274673_9
|
||||
libffi=3.3=he6710b0_2
|
||||
libgcc-ng=9.3.0=h5101ec6_17
|
||||
libgfortran-ng=7.5.0=ha8ba4b0_17
|
||||
libgfortran4=7.5.0=ha8ba4b0_17
|
||||
libgomp=9.3.0=h5101ec6_17
|
||||
libiconv=1.15=h63c8f33_5
|
||||
libidn2=2.3.2=h7f8727e_0
|
||||
libpng=1.6.37=hbc83047_0
|
||||
libstdcxx-ng=9.3.0=hd4cf53a_17
|
||||
libtasn1=4.16.0=h27cfd23_0
|
||||
libtiff=4.2.0=h85742a9_0
|
||||
libunistring=0.9.10=h27cfd23_0
|
||||
libuv=1.40.0=h7b6447c_0
|
||||
libwebp-base=1.2.2=h7f8727e_0
|
||||
locket=0.2.1=py37h06a4308_2
|
||||
lz4-c=1.9.3=h295c915_1
|
||||
matplotlib-base=3.5.1=py37ha18d171_1
|
||||
mkl=2021.4.0=h06a4308_640
|
||||
mkl-service=2.4.0=py37h7f8727e_0
|
||||
mkl_fft=1.3.1=py37hd3c417c_0
|
||||
mkl_random=1.2.2=py37h51133e4_0
|
||||
munkres=1.1.4=py_0
|
||||
ncurses=6.3=h7f8727e_2
|
||||
nettle=3.7.3=hbbd107a_1
|
||||
networkx=2.6.3=pyhd3eb1b0_0
|
||||
ninja=1.10.2=py37hd09550d_3
|
||||
numpy=1.21.2=py37h20f2e39_0
|
||||
numpy-base=1.21.2=py37h79a1101_0
|
||||
olefile=0.46=py37_0
|
||||
openh264=2.1.1=h4ff587b_0
|
||||
openssl=1.1.1n=h7f8727e_0
|
||||
packaging=21.3=pyhd3eb1b0_0
|
||||
partd=1.2.0=pyhd3eb1b0_1
|
||||
pillow=8.0.0=py37h9a89aac_0
|
||||
pip=21.2.2=py37h06a4308_0
|
||||
pyparsing=3.0.4=pyhd3eb1b0_0
|
||||
python=3.7.11=h12debd9_0
|
||||
python-dateutil=2.8.2=pyhd3eb1b0_0
|
||||
pytorch=1.8.0=py3.7_cuda10.2_cudnn7.6.5_0
|
||||
pywavelets=1.1.1=py37h7b6447c_2
|
||||
pyyaml=6.0=py37h7f8727e_1
|
||||
readline=8.1.2=h7f8727e_1
|
||||
scikit-image=0.15.0=py37hb3f55d8_2
|
||||
scipy=1.7.3=py37hc147768_0
|
||||
setuptools=58.0.4=py37h06a4308_0
|
||||
six=1.16.0=pyhd3eb1b0_1
|
||||
sqlite=3.38.0=hc218d9a_0
|
||||
tk=8.6.11=h1ccaba5_0
|
||||
toolz=0.11.2=pyhd3eb1b0_0
|
||||
torchaudio=0.8.0=py37
|
||||
torchvision=0.9.0=py37_cu102
|
||||
tqdm=4.63.0=pyhd8ed1ab_0
|
||||
typing_extensions=3.10.0.2=pyh06a4308_0
|
||||
wheel=0.37.1=pyhd3eb1b0_0
|
||||
xz=5.2.5=h7b6447c_0
|
||||
yaml=0.2.5=h7b6447c_0
|
||||
zlib=1.2.11=h7f8727e_4
|
||||
zstd=1.4.9=haebb681_0
|
713
IS-Net/train_valid_inference_main.py
Normal file
@ -0,0 +1,713 @@
|
||||
import os
|
||||
import time
|
||||
import numpy as np
|
||||
from skimage import io
|
||||
import time
|
||||
|
||||
import torch, gc
|
||||
import torch.nn as nn
|
||||
from torch.autograd import Variable
|
||||
import torch.optim as optim
|
||||
import torch.nn.functional as F
|
||||
|
||||
from data_loader_cache import get_im_gt_name_dict, create_dataloaders, GOSRandomHFlip, GOSResize, GOSRandomCrop, GOSNormalize #GOSDatasetCache,
|
||||
from basics import f1_mae_torch #normPRED, GOSPRF1ScoresCache,f1score_torch,
|
||||
from models import *
|
||||
|
||||
def get_gt_encoder(train_dataloaders, train_datasets, valid_dataloaders, valid_datasets, hypar, train_dataloaders_val, train_datasets_val): #model_path, model_save_fre, max_ite=1000000):
|
||||
|
||||
# train_dataloaders, train_datasets = create_dataloaders(train_nm_im_gt_list,
|
||||
# cache_size = hypar["cache_size"],
|
||||
# cache_boost = hypar["cache_boost_train"],
|
||||
# my_transforms = [
|
||||
# GOSRandomHFlip(),
|
||||
# # GOSResize(hypar["input_size"]),
|
||||
# # GOSRandomCrop(hypar["crop_size"]),
|
||||
# GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
|
||||
# ],
|
||||
# batch_size = hypar["batch_size_train"],
|
||||
# shuffle = True)
|
||||
|
||||
torch.manual_seed(hypar["seed"])
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(hypar["seed"])
|
||||
|
||||
print("define gt encoder ...")
|
||||
net = ISNetGTEncoder() #UNETGTENCODERCombine()
|
||||
## load the existing model gt encoder
|
||||
if(hypar["gt_encoder_model"]!=""):
|
||||
model_path = hypar["model_path"]+"/"+hypar["gt_encoder_model"]
|
||||
if torch.cuda.is_available():
|
||||
net.load_state_dict(torch.load(model_path))
|
||||
net.cuda()
|
||||
else:
|
||||
net.load_state_dict(torch.load(model_path,map_location="cpu"))
|
||||
print("gt encoder restored from the saved weights ...")
|
||||
return net ############
|
||||
|
||||
if torch.cuda.is_available():
|
||||
net.cuda()
|
||||
|
||||
print("--- define optimizer for GT Encoder---")
|
||||
optimizer = optim.Adam(net.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
|
||||
|
||||
model_path = hypar["model_path"]
|
||||
model_save_fre = hypar["model_save_fre"]
|
||||
max_ite = hypar["max_ite"]
|
||||
batch_size_train = hypar["batch_size_train"]
|
||||
batch_size_valid = hypar["batch_size_valid"]
|
||||
|
||||
if(not os.path.exists(model_path)):
|
||||
os.mkdir(model_path)
|
||||
|
||||
ite_num = hypar["start_ite"] # count the total iteration number
|
||||
ite_num4val = 0 #
|
||||
running_loss = 0.0 # count the toal loss
|
||||
running_tar_loss = 0.0 # count the target output loss
|
||||
last_f1 = [0 for x in range(len(valid_dataloaders))]
|
||||
|
||||
train_num = train_datasets[0].__len__()
|
||||
|
||||
net.train()
|
||||
|
||||
start_last = time.time()
|
||||
gos_dataloader = train_dataloaders[0]
|
||||
epoch_num = hypar["max_epoch_num"]
|
||||
notgood_cnt = 0
|
||||
for epoch in range(epoch_num): ## set the epoch num as 100000
|
||||
|
||||
for i, data in enumerate(gos_dataloader):
|
||||
|
||||
if(ite_num >= max_ite):
|
||||
print("Training Reached the Maximal Iteration Number ", max_ite)
|
||||
exit()
|
||||
|
||||
# start_read = time.time()
|
||||
ite_num = ite_num + 1
|
||||
ite_num4val = ite_num4val + 1
|
||||
|
||||
# get the inputs
|
||||
labels = data['label']
|
||||
|
||||
if(hypar["model_digit"]=="full"):
|
||||
labels = labels.type(torch.FloatTensor)
|
||||
else:
|
||||
labels = labels.type(torch.HalfTensor)
|
||||
|
||||
# wrap them in Variable
|
||||
if torch.cuda.is_available():
|
||||
labels_v = Variable(labels.cuda(), requires_grad=False)
|
||||
else:
|
||||
labels_v = Variable(labels, requires_grad=False)
|
||||
|
||||
# print("time lapse for data preparation: ", time.time()-start_read, ' s')
|
||||
|
||||
# y zero the parameter gradients
|
||||
start_inf_loss_back = time.time()
|
||||
optimizer.zero_grad()
|
||||
|
||||
ds, fs = net(labels_v)#net(inputs_v)
|
||||
loss2, loss = net.compute_loss(ds, labels_v)
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
running_loss += loss.item()
|
||||
running_tar_loss += loss2.item()
|
||||
|
||||
# del outputs, loss
|
||||
del ds, loss2, loss
|
||||
end_inf_loss_back = time.time()-start_inf_loss_back
|
||||
|
||||
print("GT Encoder Training>>>"+model_path.split('/')[-1]+" - [epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f, time-per-iter: %3f s, time_read: %3f" % (
|
||||
epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val, time.time()-start_last, time.time()-start_last-end_inf_loss_back))
|
||||
start_last = time.time()
|
||||
|
||||
if ite_num % model_save_fre == 0: # validate every 2000 iterations
|
||||
notgood_cnt += 1
|
||||
# net.eval()
|
||||
# tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time = valid_gt_encoder(net, valid_dataloaders, valid_datasets, hypar, epoch)
|
||||
tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time = valid_gt_encoder(net, train_dataloaders_val, train_datasets_val, hypar, epoch)
|
||||
|
||||
net.train() # resume train
|
||||
|
||||
tmp_out = 0
|
||||
print("last_f1:",last_f1)
|
||||
print("tmp_f1:",tmp_f1)
|
||||
for fi in range(len(last_f1)):
|
||||
if(tmp_f1[fi]>last_f1[fi]):
|
||||
tmp_out = 1
|
||||
print("tmp_out:",tmp_out)
|
||||
if(tmp_out):
|
||||
notgood_cnt = 0
|
||||
last_f1 = tmp_f1
|
||||
tmp_f1_str = [str(round(f1x,4)) for f1x in tmp_f1]
|
||||
tmp_mae_str = [str(round(mx,4)) for mx in tmp_mae]
|
||||
maxf1 = '_'.join(tmp_f1_str)
|
||||
meanM = '_'.join(tmp_mae_str)
|
||||
# .cpu().detach().numpy()
|
||||
model_name = "/GTENCODER-gpu_itr_"+str(ite_num)+\
|
||||
"_traLoss_"+str(np.round(running_loss / ite_num4val,4))+\
|
||||
"_traTarLoss_"+str(np.round(running_tar_loss / ite_num4val,4))+\
|
||||
"_valLoss_"+str(np.round(val_loss /(i_val+1),4))+\
|
||||
"_valTarLoss_"+str(np.round(tar_loss /(i_val+1),4)) + \
|
||||
"_maxF1_" + maxf1 + \
|
||||
"_mae_" + meanM + \
|
||||
"_time_" + str(np.round(np.mean(np.array(tmp_time))/batch_size_valid,6))+".pth"
|
||||
torch.save(net.state_dict(), model_path + model_name)
|
||||
|
||||
running_loss = 0.0
|
||||
running_tar_loss = 0.0
|
||||
ite_num4val = 0
|
||||
|
||||
if(tmp_f1[0]>0.99):
|
||||
print("GT encoder is well-trained and obtained...")
|
||||
return net
|
||||
|
||||
if(notgood_cnt >= hypar["early_stop"]):
|
||||
print("No improvements in the last "+str(notgood_cnt)+" validation periods, so training stopped !")
|
||||
exit()
|
||||
|
||||
print("Training Reaches The Maximum Epoch Number")
|
||||
return net
|
||||
|
||||
def valid_gt_encoder(net, valid_dataloaders, valid_datasets, hypar, epoch=0):
|
||||
net.eval()
|
||||
print("Validating...")
|
||||
epoch_num = hypar["max_epoch_num"]
|
||||
|
||||
val_loss = 0.0
|
||||
tar_loss = 0.0
|
||||
|
||||
|
||||
tmp_f1 = []
|
||||
tmp_mae = []
|
||||
tmp_time = []
|
||||
|
||||
start_valid = time.time()
|
||||
for k in range(len(valid_dataloaders)):
|
||||
|
||||
valid_dataloader = valid_dataloaders[k]
|
||||
valid_dataset = valid_datasets[k]
|
||||
|
||||
val_num = valid_dataset.__len__()
|
||||
mybins = np.arange(0,256)
|
||||
PRE = np.zeros((val_num,len(mybins)-1))
|
||||
REC = np.zeros((val_num,len(mybins)-1))
|
||||
F1 = np.zeros((val_num,len(mybins)-1))
|
||||
MAE = np.zeros((val_num))
|
||||
|
||||
val_cnt = 0.0
|
||||
for i_val, data_val in enumerate(valid_dataloader):
|
||||
|
||||
# imidx_val, inputs_val, labels_val, shapes_val = data_val['imidx'], data_val['image'], data_val['label'], data_val['shape']
|
||||
imidx_val, labels_val, shapes_val = data_val['imidx'], data_val['label'], data_val['shape']
|
||||
|
||||
if(hypar["model_digit"]=="full"):
|
||||
labels_val = labels_val.type(torch.FloatTensor)
|
||||
else:
|
||||
labels_val = labels_val.type(torch.HalfTensor)
|
||||
|
||||
# wrap them in Variable
|
||||
if torch.cuda.is_available():
|
||||
labels_val_v = Variable(labels_val.cuda(), requires_grad=False)
|
||||
else:
|
||||
labels_val_v = Variable(labels_val,requires_grad=False)
|
||||
|
||||
t_start = time.time()
|
||||
ds_val = net(labels_val_v)[0]
|
||||
t_end = time.time()-t_start
|
||||
tmp_time.append(t_end)
|
||||
|
||||
# loss2_val, loss_val = muti_loss_fusion(ds_val, labels_val_v)
|
||||
loss2_val, loss_val = net.compute_loss(ds_val, labels_val_v)
|
||||
|
||||
# compute F measure
|
||||
for t in range(hypar["batch_size_valid"]):
|
||||
val_cnt = val_cnt + 1.0
|
||||
print("num of val: ", val_cnt)
|
||||
i_test = imidx_val[t].data.numpy()
|
||||
|
||||
pred_val = ds_val[0][t,:,:,:] # B x 1 x H x W
|
||||
|
||||
## recover the prediction spatial size to the orignal image size
|
||||
pred_val = torch.squeeze(F.upsample(torch.unsqueeze(pred_val,0),(shapes_val[t][0],shapes_val[t][1]),mode='bilinear'))
|
||||
|
||||
ma = torch.max(pred_val)
|
||||
mi = torch.min(pred_val)
|
||||
pred_val = (pred_val-mi)/(ma-mi) # max = 1
|
||||
# pred_val = normPRED(pred_val)
|
||||
|
||||
gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255
|
||||
with torch.no_grad():
|
||||
gt = torch.tensor(gt).cuda()
|
||||
|
||||
pre,rec,f1,mae = f1_mae_torch(pred_val*255, gt, valid_dataset, i_test, mybins, hypar)
|
||||
|
||||
PRE[i_test,:]=pre
|
||||
REC[i_test,:] = rec
|
||||
F1[i_test,:] = f1
|
||||
MAE[i_test] = mae
|
||||
|
||||
del ds_val, gt
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# if(loss_val.data[0]>1):
|
||||
val_loss += loss_val.item()#data[0]
|
||||
tar_loss += loss2_val.item()#data[0]
|
||||
|
||||
print("[validating: %5d/%5d] val_ls:%f, tar_ls: %f, f1: %f, mae: %f, time: %f"% (i_val, val_num, val_loss / (i_val + 1), tar_loss / (i_val + 1), np.amax(F1[i_test,:]), MAE[i_test],t_end))
|
||||
|
||||
del loss2_val, loss_val
|
||||
|
||||
print('============================')
|
||||
PRE_m = np.mean(PRE,0)
|
||||
REC_m = np.mean(REC,0)
|
||||
f1_m = (1+0.3)*PRE_m*REC_m/(0.3*PRE_m+REC_m+1e-8)
|
||||
# print('--------------:', np.mean(f1_m))
|
||||
tmp_f1.append(np.amax(f1_m))
|
||||
tmp_mae.append(np.mean(MAE))
|
||||
print("The max F1 Score: %f"%(np.max(f1_m)))
|
||||
print("MAE: ", np.mean(MAE))
|
||||
|
||||
# print('[epoch: %3d/%3d, ite: %5d] tra_ls: %3f, val_ls: %3f, tar_ls: %3f, maxf1: %3f, val_time: %6f'% (epoch + 1, epoch_num, ite_num, running_loss / ite_num4val, val_loss/val_cnt, tar_loss/val_cnt, tmp_f1[-1], time.time()-start_valid))
|
||||
|
||||
return tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time
|
||||
|
||||
def train(net, optimizer, train_dataloaders, train_datasets, valid_dataloaders, valid_datasets, hypar,train_dataloaders_val, train_datasets_val): #model_path, model_save_fre, max_ite=1000000):
|
||||
|
||||
if hypar["interm_sup"]:
|
||||
print("Get the gt encoder ...")
|
||||
featurenet = get_gt_encoder(train_dataloaders, train_datasets, valid_dataloaders, valid_datasets, hypar,train_dataloaders_val, train_datasets_val)
|
||||
## freeze the weights of gt encoder
|
||||
for param in featurenet.parameters():
|
||||
param.requires_grad=False
|
||||
|
||||
|
||||
model_path = hypar["model_path"]
|
||||
model_save_fre = hypar["model_save_fre"]
|
||||
max_ite = hypar["max_ite"]
|
||||
batch_size_train = hypar["batch_size_train"]
|
||||
batch_size_valid = hypar["batch_size_valid"]
|
||||
|
||||
if(not os.path.exists(model_path)):
|
||||
os.mkdir(model_path)
|
||||
|
||||
ite_num = hypar["start_ite"] # count the toal iteration number
|
||||
ite_num4val = 0 #
|
||||
running_loss = 0.0 # count the toal loss
|
||||
running_tar_loss = 0.0 # count the target output loss
|
||||
last_f1 = [0 for x in range(len(valid_dataloaders))]
|
||||
|
||||
train_num = train_datasets[0].__len__()
|
||||
|
||||
net.train()
|
||||
|
||||
start_last = time.time()
|
||||
gos_dataloader = train_dataloaders[0]
|
||||
epoch_num = hypar["max_epoch_num"]
|
||||
notgood_cnt = 0
|
||||
for epoch in range(epoch_num): ## set the epoch num as 100000
|
||||
|
||||
for i, data in enumerate(gos_dataloader):
|
||||
|
||||
if(ite_num >= max_ite):
|
||||
print("Training Reached the Maximal Iteration Number ", max_ite)
|
||||
exit()
|
||||
|
||||
# start_read = time.time()
|
||||
ite_num = ite_num + 1
|
||||
ite_num4val = ite_num4val + 1
|
||||
|
||||
# get the inputs
|
||||
inputs, labels = data['image'], data['label']
|
||||
|
||||
if(hypar["model_digit"]=="full"):
|
||||
inputs = inputs.type(torch.FloatTensor)
|
||||
labels = labels.type(torch.FloatTensor)
|
||||
else:
|
||||
inputs = inputs.type(torch.HalfTensor)
|
||||
labels = labels.type(torch.HalfTensor)
|
||||
|
||||
# wrap them in Variable
|
||||
if torch.cuda.is_available():
|
||||
inputs_v, labels_v = Variable(inputs.cuda(), requires_grad=False), Variable(labels.cuda(), requires_grad=False)
|
||||
else:
|
||||
inputs_v, labels_v = Variable(inputs, requires_grad=False), Variable(labels, requires_grad=False)
|
||||
|
||||
# print("time lapse for data preparation: ", time.time()-start_read, ' s')
|
||||
|
||||
# y zero the parameter gradients
|
||||
start_inf_loss_back = time.time()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if hypar["interm_sup"]:
|
||||
# forward + backward + optimize
|
||||
ds,dfs = net(inputs_v)
|
||||
_,fs = featurenet(labels_v) ## extract the gt encodings
|
||||
loss2, loss = net.compute_loss_kl(ds, labels_v, dfs, fs, mode='MSE')
|
||||
else:
|
||||
# forward + backward + optimize
|
||||
ds,_ = net(inputs_v)
|
||||
loss2, loss = muti_loss_fusion(ds, labels_v)
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# # print statistics
|
||||
running_loss += loss.item()
|
||||
running_tar_loss += loss2.item()
|
||||
|
||||
# del outputs, loss
|
||||
del ds, loss2, loss
|
||||
end_inf_loss_back = time.time()-start_inf_loss_back
|
||||
|
||||
print(">>>"+model_path.split('/')[-1]+" - [epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f, time-per-iter: %3f s, time_read: %3f" % (
|
||||
epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val, time.time()-start_last, time.time()-start_last-end_inf_loss_back))
|
||||
start_last = time.time()
|
||||
|
||||
if ite_num % model_save_fre == 0: # validate every 2000 iterations
|
||||
notgood_cnt += 1
|
||||
net.eval()
|
||||
tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time = valid(net, valid_dataloaders, valid_datasets, hypar, epoch)
|
||||
net.train() # resume train
|
||||
|
||||
tmp_out = 0
|
||||
print("last_f1:",last_f1)
|
||||
print("tmp_f1:",tmp_f1)
|
||||
for fi in range(len(last_f1)):
|
||||
if(tmp_f1[fi]>last_f1[fi]):
|
||||
tmp_out = 1
|
||||
print("tmp_out:",tmp_out)
|
||||
if(tmp_out):
|
||||
notgood_cnt = 0
|
||||
last_f1 = tmp_f1
|
||||
tmp_f1_str = [str(round(f1x,4)) for f1x in tmp_f1]
|
||||
tmp_mae_str = [str(round(mx,4)) for mx in tmp_mae]
|
||||
maxf1 = '_'.join(tmp_f1_str)
|
||||
meanM = '_'.join(tmp_mae_str)
|
||||
# .cpu().detach().numpy()
|
||||
model_name = "/gpu_itr_"+str(ite_num)+\
|
||||
"_traLoss_"+str(np.round(running_loss / ite_num4val,4))+\
|
||||
"_traTarLoss_"+str(np.round(running_tar_loss / ite_num4val,4))+\
|
||||
"_valLoss_"+str(np.round(val_loss /(i_val+1),4))+\
|
||||
"_valTarLoss_"+str(np.round(tar_loss /(i_val+1),4)) + \
|
||||
"_maxF1_" + maxf1 + \
|
||||
"_mae_" + meanM + \
|
||||
"_time_" + str(np.round(np.mean(np.array(tmp_time))/batch_size_valid,6))+".pth"
|
||||
torch.save(net.state_dict(), model_path + model_name)
|
||||
|
||||
running_loss = 0.0
|
||||
running_tar_loss = 0.0
|
||||
ite_num4val = 0
|
||||
|
||||
if(notgood_cnt >= hypar["early_stop"]):
|
||||
print("No improvements in the last "+str(notgood_cnt)+" validation periods, so training stopped !")
|
||||
exit()
|
||||
|
||||
print("Training Reaches The Maximum Epoch Number")
|
||||
|
||||
def valid(net, valid_dataloaders, valid_datasets, hypar, epoch=0):
|
||||
net.eval()
|
||||
print("Validating...")
|
||||
epoch_num = hypar["max_epoch_num"]
|
||||
|
||||
val_loss = 0.0
|
||||
tar_loss = 0.0
|
||||
val_cnt = 0.0
|
||||
|
||||
tmp_f1 = []
|
||||
tmp_mae = []
|
||||
tmp_time = []
|
||||
|
||||
start_valid = time.time()
|
||||
|
||||
for k in range(len(valid_dataloaders)):
|
||||
|
||||
valid_dataloader = valid_dataloaders[k]
|
||||
valid_dataset = valid_datasets[k]
|
||||
|
||||
val_num = valid_dataset.__len__()
|
||||
mybins = np.arange(0,256)
|
||||
PRE = np.zeros((val_num,len(mybins)-1))
|
||||
REC = np.zeros((val_num,len(mybins)-1))
|
||||
F1 = np.zeros((val_num,len(mybins)-1))
|
||||
MAE = np.zeros((val_num))
|
||||
|
||||
for i_val, data_val in enumerate(valid_dataloader):
|
||||
val_cnt = val_cnt + 1.0
|
||||
imidx_val, inputs_val, labels_val, shapes_val = data_val['imidx'], data_val['image'], data_val['label'], data_val['shape']
|
||||
|
||||
if(hypar["model_digit"]=="full"):
|
||||
inputs_val = inputs_val.type(torch.FloatTensor)
|
||||
labels_val = labels_val.type(torch.FloatTensor)
|
||||
else:
|
||||
inputs_val = inputs_val.type(torch.HalfTensor)
|
||||
labels_val = labels_val.type(torch.HalfTensor)
|
||||
|
||||
# wrap them in Variable
|
||||
if torch.cuda.is_available():
|
||||
inputs_val_v, labels_val_v = Variable(inputs_val.cuda(), requires_grad=False), Variable(labels_val.cuda(), requires_grad=False)
|
||||
else:
|
||||
inputs_val_v, labels_val_v = Variable(inputs_val, requires_grad=False), Variable(labels_val,requires_grad=False)
|
||||
|
||||
t_start = time.time()
|
||||
ds_val = net(inputs_val_v)[0]
|
||||
t_end = time.time()-t_start
|
||||
tmp_time.append(t_end)
|
||||
|
||||
# loss2_val, loss_val = muti_loss_fusion(ds_val, labels_val_v)
|
||||
loss2_val, loss_val = net.compute_loss(ds_val, labels_val_v)
|
||||
|
||||
# compute F measure
|
||||
for t in range(hypar["batch_size_valid"]):
|
||||
i_test = imidx_val[t].data.numpy()
|
||||
|
||||
pred_val = ds_val[0][t,:,:,:] # B x 1 x H x W
|
||||
|
||||
## recover the prediction spatial size to the orignal image size
|
||||
pred_val = torch.squeeze(F.upsample(torch.unsqueeze(pred_val,0),(shapes_val[t][0],shapes_val[t][1]),mode='bilinear'))
|
||||
|
||||
# pred_val = normPRED(pred_val)
|
||||
ma = torch.max(pred_val)
|
||||
mi = torch.min(pred_val)
|
||||
pred_val = (pred_val-mi)/(ma-mi) # max = 1
|
||||
|
||||
gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255
|
||||
with torch.no_grad():
|
||||
gt = torch.tensor(gt).cuda()
|
||||
|
||||
pre,rec,f1,mae = f1_mae_torch(pred_val*255, gt, valid_dataset, i_test, mybins, hypar)
|
||||
|
||||
|
||||
PRE[i_test,:]=pre
|
||||
REC[i_test,:] = rec
|
||||
F1[i_test,:] = f1
|
||||
MAE[i_test] = mae
|
||||
|
||||
del ds_val, gt
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# if(loss_val.data[0]>1):
|
||||
val_loss += loss_val.item()#data[0]
|
||||
tar_loss += loss2_val.item()#data[0]
|
||||
|
||||
print("[validating: %5d/%5d] val_ls:%f, tar_ls: %f, f1: %f, mae: %f, time: %f"% (i_val, val_num, val_loss / (i_val + 1), tar_loss / (i_val + 1), np.amax(F1[i_test,:]), MAE[i_test],t_end))
|
||||
|
||||
del loss2_val, loss_val
|
||||
|
||||
print('============================')
|
||||
PRE_m = np.mean(PRE,0)
|
||||
REC_m = np.mean(REC,0)
|
||||
f1_m = (1+0.3)*PRE_m*REC_m/(0.3*PRE_m+REC_m+1e-8)
|
||||
|
||||
tmp_f1.append(np.amax(f1_m))
|
||||
tmp_mae.append(np.mean(MAE))
|
||||
|
||||
return tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time
|
||||
|
||||
def main(train_datasets,
|
||||
valid_datasets,
|
||||
hypar): # model: "train", "test"
|
||||
|
||||
### --- Step 1: Build datasets and dataloaders ---
|
||||
dataloaders_train = []
|
||||
dataloaders_valid = []
|
||||
|
||||
if(hypar["mode"]=="train"):
|
||||
print("--- create training dataloader ---")
|
||||
## collect training dataset
|
||||
train_nm_im_gt_list = get_im_gt_name_dict(train_datasets, flag="train")
|
||||
## build dataloader for training datasets
|
||||
train_dataloaders, train_datasets = create_dataloaders(train_nm_im_gt_list,
|
||||
cache_size = hypar["cache_size"],
|
||||
cache_boost = hypar["cache_boost_train"],
|
||||
my_transforms = [
|
||||
GOSRandomHFlip(), ## this line can be uncommented for horizontal flip augmetation
|
||||
# GOSResize(hypar["input_size"]),
|
||||
# GOSRandomCrop(hypar["crop_size"]), ## this line can be uncommented for randomcrop augmentation
|
||||
GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
|
||||
],
|
||||
batch_size = hypar["batch_size_train"],
|
||||
shuffle = True)
|
||||
train_dataloaders_val, train_datasets_val = create_dataloaders(train_nm_im_gt_list,
|
||||
cache_size = hypar["cache_size"],
|
||||
cache_boost = hypar["cache_boost_train"],
|
||||
my_transforms = [
|
||||
GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
|
||||
],
|
||||
batch_size = hypar["batch_size_valid"],
|
||||
shuffle = False)
|
||||
print(len(train_dataloaders), " train dataloaders created")
|
||||
|
||||
print("--- create valid dataloader ---")
|
||||
## build dataloader for validation or testing
|
||||
valid_nm_im_gt_list = get_im_gt_name_dict(valid_datasets, flag="valid")
|
||||
## build dataloader for training datasets
|
||||
valid_dataloaders, valid_datasets = create_dataloaders(valid_nm_im_gt_list,
|
||||
cache_size = hypar["cache_size"],
|
||||
cache_boost = hypar["cache_boost_valid"],
|
||||
my_transforms = [
|
||||
GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
|
||||
# GOSResize(hypar["input_size"])
|
||||
],
|
||||
batch_size=hypar["batch_size_valid"],
|
||||
shuffle=False)
|
||||
print(len(valid_dataloaders), " valid dataloaders created")
|
||||
# print(valid_datasets[0]["data_name"])
|
||||
|
||||
### --- Step 2: Build Model and Optimizer ---
|
||||
print("--- build model ---")
|
||||
net = hypar["model"]#GOSNETINC(3,1)
|
||||
|
||||
# convert to half precision
|
||||
if(hypar["model_digit"]=="half"):
|
||||
net.half()
|
||||
for layer in net.modules():
|
||||
if isinstance(layer, nn.BatchNorm2d):
|
||||
layer.float()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
net.cuda()
|
||||
|
||||
if(hypar["restore_model"]!=""):
|
||||
print("restore model from:")
|
||||
print(hypar["model_path"]+"/"+hypar["restore_model"])
|
||||
if torch.cuda.is_available():
|
||||
net.load_state_dict(torch.load(hypar["model_path"]+"/"+hypar["restore_model"]))
|
||||
else:
|
||||
net.load_state_dict(torch.load(hypar["model_path"]+"/"+hypar["restore_model"],map_location="cpu"))
|
||||
|
||||
print("--- define optimizer ---")
|
||||
optimizer = optim.Adam(net.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
|
||||
|
||||
### --- Step 3: Train or Valid Model ---
|
||||
if(hypar["mode"]=="train"):
|
||||
train(net,
|
||||
optimizer,
|
||||
train_dataloaders,
|
||||
train_datasets,
|
||||
valid_dataloaders,
|
||||
valid_datasets,
|
||||
hypar,
|
||||
train_dataloaders_val, train_datasets_val)
|
||||
else:
|
||||
valid(net,
|
||||
valid_dataloaders,
|
||||
valid_datasets,
|
||||
hypar)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
### --------------- STEP 1: Configuring the Train, Valid and Test datasets ---------------
|
||||
## configure the train, valid and inference datasets
|
||||
train_datasets, valid_datasets = [], []
|
||||
dataset_1, dataset_1 = {}, {}
|
||||
|
||||
dataset_tr = {"name": "DIS5K-TR",
|
||||
"im_dir": "../DIS5K/DIS-TR/im",
|
||||
"gt_dir": "../DIS5K/DIS-TR/gt",
|
||||
"im_ext": ".jpg",
|
||||
"gt_ext": ".png",
|
||||
"cache_dir":"../DIS5K-Cache/DIS-TR"}
|
||||
|
||||
dataset_vd = {"name": "DIS5K-VD",
|
||||
"im_dir": "../DIS5K/DIS-VD/im",
|
||||
"gt_dir": "../DIS5K/DIS-VD/gt",
|
||||
"im_ext": ".jpg",
|
||||
"gt_ext": ".png",
|
||||
"cache_dir":"../DIS5K-Cache/DIS-VD"}
|
||||
|
||||
dataset_te1 = {"name": "DIS5K-TE1",
|
||||
"im_dir": "../DIS5K/DIS-TE1/im",
|
||||
"gt_dir": "../DIS5K/DIS-TE1/gt",
|
||||
"im_ext": ".jpg",
|
||||
"gt_ext": ".png",
|
||||
"cache_dir":"../DIS5K-Cache/DIS-TE1"}
|
||||
|
||||
dataset_te2 = {"name": "DIS5K-TE2",
|
||||
"im_dir": "../DIS5K/DIS-TE2/im",
|
||||
"gt_dir": "../DIS5K/DIS-TE2/gt",
|
||||
"im_ext": ".jpg",
|
||||
"gt_ext": ".png",
|
||||
"cache_dir":"../DIS5K-Cache/DIS-TE2"}
|
||||
|
||||
dataset_te3 = {"name": "DIS5K-TE3",
|
||||
"im_dir": "../DIS5K/DIS-TE3/im",
|
||||
"gt_dir": "../DIS5K/DIS-TE3/gt",
|
||||
"im_ext": ".jpg",
|
||||
"gt_ext": ".png",
|
||||
"cache_dir":"../DIS5K-Cache/DIS-TE3"}
|
||||
|
||||
dataset_te4 = {"name": "DIS5K-TE4",
|
||||
"im_dir": "../DIS5K/DIS-TE4/im",
|
||||
"gt_dir": "../DIS5K/DIS-TE4/gt",
|
||||
"im_ext": ".jpg",
|
||||
"gt_ext": ".png",
|
||||
"cache_dir":"../DIS5K-Cache/DIS-TE4"}
|
||||
|
||||
train_datasets = [dataset_tr] ## users can create mutiple dictionary for setting a list of datasets as training set
|
||||
# valid_datasets = [dataset_vd] ## users can create mutiple dictionary for setting a list of datasets as vaidation sets or inference sets
|
||||
valid_datasets = [dataset_vd] #, dataset_te1, dataset_te2, dataset_te3, dataset_te4] # and hypar["mode"] = "valid" for inference,
|
||||
|
||||
### --------------- STEP 2: Configuring the hyperparamters for Training, validation and inferencing ---------------
|
||||
hypar = {}
|
||||
|
||||
## -- 2.1. configure the model saving or restoring path --
|
||||
hypar["mode"] = "train"
|
||||
## "train": for training,
|
||||
## "valid": for validation and inferening,
|
||||
## in "valid" mode, it will calculate the accuracy as well as save the prediciton results into the "hypar["valid_out_dir"]", which shouldn't be ""
|
||||
## otherwise only accuracy will be calculated and no predictions will be saved
|
||||
hypar["interm_sup"] = False ## indicate if activate intermediate feature supervision
|
||||
|
||||
if hypar["mode"] == "train":
|
||||
hypar["valid_out_dir"] = "" ## for "train" model leave it as "", for "valid"("inference") mode: set it according to your local directory
|
||||
hypar["model_path"] ="../saved_models/IS-Net-test" ## model weights saving (or restoring) path
|
||||
hypar["restore_model"] = "" ## name of the segmentation model weights .pth for resume training process from last stop or for the inferencing
|
||||
hypar["start_ite"] = 0 ## start iteration for the training, can be changed to match the restored training process
|
||||
hypar["gt_encoder_model"] = ""
|
||||
else: ## configure the segmentation output path and the to-be-used model weights path
|
||||
hypar["valid_out_dir"] = "../DIS5K-Results-test" ## output inferenced segmentation maps into this fold
|
||||
hypar["model_path"] ="../saved_models/IS-Net" ## load trained weights from this path
|
||||
hypar["restore_model"] = "isnet.pth" ## name of the to-be-loaded weights
|
||||
|
||||
# if hypar["restore_model"]!="":
|
||||
# hypar["start_ite"] = int(hypar["restore_model"].split("_")[2])
|
||||
|
||||
## -- 2.2. choose floating point accuracy --
|
||||
hypar["model_digit"] = "full" ## indicates "half" or "full" accuracy of float number
|
||||
hypar["seed"] = 0
|
||||
|
||||
## -- 2.3. cache data spatial size --
|
||||
## To handle large size input images, which take a lot of time for loading in training,
|
||||
# we introduce the cache mechanism for pre-convering and resizing the jpg and png images into .pt file
|
||||
hypar["cache_size"] = [1024, 1024] ## cached input spatial resolution, can be configured into different size
|
||||
hypar["cache_boost_train"] = False ## "True" or "False", indicates wheather to load all the training datasets into RAM, True will greatly speed the training process while requires more RAM
|
||||
hypar["cache_boost_valid"] = False ## "True" or "False", indicates wheather to load all the validation datasets into RAM, True will greatly speed the training process while requires more RAM
|
||||
|
||||
## --- 2.4. data augmentation parameters ---
|
||||
hypar["input_size"] = [1024, 1024] ## mdoel input spatial size, usually use the same value hypar["cache_size"], which means we don't further resize the images
|
||||
hypar["crop_size"] = [1024, 1024] ## random crop size from the input, it is usually set as smaller than hypar["cache_size"], e.g., [920,920] for data augmentation
|
||||
hypar["random_flip_h"] = 1 ## horizontal flip, currently hard coded in the dataloader and it is not in use
|
||||
hypar["random_flip_v"] = 0 ## vertical flip , currently not in use
|
||||
|
||||
## --- 2.5. define model ---
|
||||
print("building model...")
|
||||
hypar["model"] = ISNetDIS() #U2NETFASTFEATURESUP()
|
||||
hypar["early_stop"] = 20 ## stop the training when no improvement in the past 20 validation periods, smaller numbers can be used here e.g., 5 or 10.
|
||||
hypar["model_save_fre"] = 2000 ## valid and save model weights every 2000 iterations
|
||||
|
||||
hypar["batch_size_train"] = 8 ## batch size for training
|
||||
hypar["batch_size_valid"] = 1 ## batch size for validation and inferencing
|
||||
print("batch size: ", hypar["batch_size_train"])
|
||||
|
||||
hypar["max_ite"] = 10000000 ## if early stop couldn't stop the training process, stop it by the max_ite_num
|
||||
hypar["max_epoch_num"] = 1000000 ## if early stop and max_ite couldn't stop the training process, stop it by the max_epoch_num
|
||||
|
||||
main(train_datasets,
|
||||
valid_datasets,
|
||||
hypar=hypar)
|
161
README.md
@ -2,27 +2,147 @@
|
||||
<img width="420" height="320" src="figures/dis-logo-official.png">
|
||||
</p>
|
||||
|
||||
![dis5k-v1-sailship](figures/dis5k-v1-sailship.jpeg)
|
||||
|
||||
<br>
|
||||
|
||||
## [Highly Accurate Dichotomous Image Segmentation (ECCV 2022)](https://arxiv.org/pdf/2203.03041.pdf)
|
||||
[Xuebin Qin](https://xuebinqin.github.io/), [Hang Dai](https://scholar.google.co.uk/citations?user=6yvjpQQAAAAJ&hl=en), [Xiaobin Hu](https://scholar.google.de/citations?user=3lMuodUAAAAJ&hl=en), [Deng-Ping Fan*](https://dengpingfan.github.io/), [Ling Shao](https://scholar.google.com/citations?user=z84rLjoAAAAJ&hl=en) and [Luc Van Gool](https://scholar.google.com/citations?user=TwMib_QAAAAJ&hl=en).
|
||||
[Xuebin Qin](https://xuebinqin.github.io/), [Hang Dai](https://scholar.google.co.uk/citations?user=6yvjpQQAAAAJ&hl=en), [Xiaobin Hu](https://scholar.google.de/citations?user=3lMuodUAAAAJ&hl=en), [Deng-Ping Fan*](https://dengpingfan.github.io/), [Ling Shao](https://scholar.google.com/citations?user=z84rLjoAAAAJ&hl=en), [Luc Van Gool](https://scholar.google.com/citations?user=TwMib_QAAAAJ&hl=en).
|
||||
|
||||
<br>
|
||||
|
||||
This is the official repo for our new project DIS:
|
||||
[**[Project Page]**](https://xuebinqin.github.io/dis/index.html)
|
||||
## This is the official repo for our newly formulated DIS task:
|
||||
[**Project Page**](https://xuebinqin.github.io/dis/index.html), [**Arxiv**](https://arxiv.org/pdf/2203.03041.pdf).
|
||||
|
||||
<br>
|
||||
|
||||
## Updates !!!
|
||||
|
||||
<br>
|
||||
|
||||
## ** (2022-Jul.-17)** Our paper, code and dataset are now officially released!!! Please check our project page for more details: [**Project Page**](https://xuebinqin.github.io/dis/index.html).<br>
|
||||
** (2022-Jul.-5)** Our DIS work is now accepted by ECCV 2022, the code and dataset will be released before July 17th, 2022. Please be aware of our updates.
|
||||
|
||||
## DEMO
|
||||
![ship-demo](figures/ship-demo.gif)
|
||||
![bg-removal](figures/bg-removal.gif)
|
||||
![view-move](figures/view-move.gif)
|
||||
![motor-demo](figures/motor-demo.gif)
|
||||
<br>
|
||||
|
||||
## Comparisons Against SOTAs
|
||||
## 1. [Our DIS5K Dataset V1.0 (Version Alias: DIS5K Sailship)](https://xuebinqin.github.io/dis/index.html)
|
||||
|
||||
<br>
|
||||
|
||||
### Download: [Google Drive](https://drive.google.com/file/d/1jOC2zK0GowBvEt03B7dGugCRDVAoeIqq/view?usp=sharing) or [Baidu Pan 提取码:rtgw](https://pan.baidu.com/s/1y6CQJYledfYyEO0C_Gejpw?pwd=rtgw)
|
||||
|
||||
![dis5k-dataset-v1-sailship](figures/DIS5k-dataset-v1-sailship.png)
|
||||
![complexities-qual](figures/complexities-qual.jpeg)
|
||||
![categories](figures/categories.jpeg)
|
||||
|
||||
<br>
|
||||
|
||||
## 2. APPLICATIONS of Our DIS5K Dataset
|
||||
|
||||
<br>
|
||||
|
||||
### 3D Modeling
|
||||
![3d-modeling](figures/3d-modeling.png)
|
||||
|
||||
### Image Editing
|
||||
![ship-demo](figures/ship-demo.gif)
|
||||
### Art Design Materials
|
||||
![bg-removal](figures/bg-removal.gif)
|
||||
### Still Image Animation
|
||||
![view-move](figures/view-move.gif)
|
||||
### AR
|
||||
![motor-demo](figures/motor-demo.gif)
|
||||
### 3D Rendering
|
||||
![video-3d](figures/video-3d.gif)
|
||||
|
||||
<br>
|
||||
|
||||
## 3. Architecture of Our IS-Net
|
||||
|
||||
<br>
|
||||
|
||||
![is-net](figures/is-net.png)
|
||||
|
||||
<br>
|
||||
|
||||
## 4. Human Correction Efforts (HCE)
|
||||
|
||||
<br>
|
||||
|
||||
![hce-metric](figures/hce-metric.png)
|
||||
|
||||
<br>
|
||||
|
||||
## 5. Experimental Results
|
||||
|
||||
<br>
|
||||
|
||||
### Predicted Maps, [(Google Drive)](https://drive.google.com/file/d/1FMtDLFrL6xVc41eKlLnuZWMBAErnKv0Y/view?usp=sharing), [(Baidu Pan 提取码:ph1d)](https://pan.baidu.com/s/1WUk2RYYpii2xzrvLna9Fsg?pwd=ph1d), of Our IS-Net and Other SOTAs
|
||||
|
||||
### Qualitative Comparisons Against SOTAs
|
||||
![qual-comp](figures/qual-comp.jpg)
|
||||
|
||||
### Quantitative Comparisons Against SOTAs
|
||||
![qual-comp](figures/quan-comp.png)
|
||||
|
||||
<br>
|
||||
|
||||
## 6. Run Our Code
|
||||
|
||||
<br>
|
||||
|
||||
### (1) Clone this repo
|
||||
```
|
||||
git clone https://github.com/xuebinqin/DIS.git
|
||||
```
|
||||
|
||||
### (2) Configuring the environment: go to the root ```DIS``` folder and run
|
||||
```
|
||||
conda env create -f pytorch18.yml
|
||||
```
|
||||
Or you can check the ```requirements.txt``` to configure the dependancies.
|
||||
|
||||
### (3) Train:
|
||||
#### (a) Open ```train_valid_inference_main.py```, set the path of your to-be-inferenced ```train_datasets``` and ```valid_datasets```, e.g., ```valid_datasets=[dataset_vd]```
|
||||
#### (b) Set the ```hypar["mode"]``` to ```"train"```
|
||||
#### (c) Create a new folder ```your_model_weights``` in the directory ```saved_models``` and set it as the ```hypar["model_path"] ="../saved_models/your_model_weights"``` and make sure ```hypar["valid_out_dir"]```(line 668) is set to ```""```, otherwise the prediction maps of the validation stage will be saved to that directory, which will slow the training speed down
|
||||
#### (d) Run
|
||||
```
|
||||
python train_valid_inference_main.py
|
||||
```
|
||||
|
||||
|
||||
### (4) Inference
|
||||
#### (a). Download the pre-trained weights (for fair academic comparisons only, the optimized model for engineering or common use will be released soon) ```isnet.pth``` from [(Google Drive)](https://drive.google.com/file/d/1KyMpRjewZdyYfxHPYcd-ZbanIXtin0Sn/view?usp=sharing) or [(Baidu Pan 提取码:xbfk)](https://pan.baidu.com/s/1-X2WutiBkWPt-oakuvZ10w?pwd=xbfk) and store ```isnet.pth``` in ```saved_models/IS-Net```
|
||||
#### (b) Open ```train_valid_inference_main.py```, set the path of your to-be-inferenced ```valid_datasets```, e.g., ```valid_datasets=[dataset_te1, dataset_te2, dataset_te3, dataset_te4]```
|
||||
#### (c) Set the ```hypar["mode"]``` to ```"valid"```
|
||||
#### (d) Set the output directory of your predicted maps, e.g., ```hypar["valid_out_dir"] = "../DIS5K-Results-test"```
|
||||
#### (e) Run
|
||||
```
|
||||
python train_valid_inference_main.py
|
||||
```
|
||||
|
||||
### (5) Use of our Human Correction Efforts(HCE) metric, set the ground truth directory ```gt_root``` and the prediction directory ```pred_root```. To reduce the time costs for computing HCE, the skeletion of the DIS5K dataset can be pre-computed and stored in ```gt_ske_root```. If ```gt_ske_root=""```, the HCE code will compute the skeleton online which usually takes a lot for time for large size ground truth. Then, run ```python hce_metric_main.py```. Other metrics are evaluated based on the [SOCToolbox](https://github.com/mczhuge/SOCToolbox).
|
||||
|
||||
<br>
|
||||
|
||||
## 7. Term of Use
|
||||
Our code and evaluation metric use Apache License 2.0. The Terms of use for our DIS5K dataset is provided as [DIS5K-Dataset-Terms-of-Use.pdf](DIS5K-Dataset-Terms-of-Use.pdf).
|
||||
|
||||
<br>
|
||||
|
||||
## Acknowledgements
|
||||
|
||||
<br>
|
||||
|
||||
We would like to thank Dr. [Ibrahim Almakky](https://scholar.google.co.uk/citations?user=T9MTcK0AAAAJ&hl=en) for his helps in implementing the dataloader cache machanism of loading large-size training samples and Jiayi Zhu for his efforts in re-organizing our code and dataset.
|
||||
|
||||
<br>
|
||||
|
||||
## Citation
|
||||
|
||||
<br>
|
||||
|
||||
```
|
||||
@InProceedings{qin2022,
|
||||
author={Xuebin Qin and Hang Dai and Xiaobin Hu and Deng-Ping Fan and Ling Shao and Luc Van Gool},
|
||||
@ -31,8 +151,15 @@ This is the official repo for our new project DIS:
|
||||
year={2022}
|
||||
}
|
||||
```
|
||||
## Our Previous Works
|
||||
|
||||
<br>
|
||||
|
||||
## Our Previous Works: [U<sup>2</sup>-Net](https://github.com/xuebinqin/U-2-Net), [BASNet](https://github.com/xuebinqin/BASNet).
|
||||
|
||||
<br>
|
||||
|
||||
```
|
||||
|
||||
@InProceedings{Qin_2020_PR,
|
||||
title = {U2-Net: Going Deeper with Nested U-Structure for Salient Object Detection},
|
||||
author = {Qin, Xuebin and Zhang, Zichen and Huang, Chenyang and Dehghan, Masood and Zaiane, Osmar and Jagersand, Martin},
|
||||
@ -42,13 +169,6 @@ This is the official repo for our new project DIS:
|
||||
year = {2020}
|
||||
}
|
||||
|
||||
@article{qin2021boundary,
|
||||
title={Boundary-aware segmentation network for mobile and web applications},
|
||||
author={Qin, Xuebin and Fan, Deng-Ping and Huang, Chenyang and Diagne, Cyril and Zhang, Zichen and Sant'Anna, Adri{\`a} Cabeza and Suarez, Albert and Jagersand, Martin and Shao, Ling},
|
||||
journal={arXiv preprint arXiv:2101.04704},
|
||||
year={2021}
|
||||
}
|
||||
|
||||
@InProceedings{Qin_2019_CVPR,
|
||||
author = {Qin, Xuebin and Zhang, Zichen and Huang, Chenyang and Gao, Chao and Dehghan, Masood and Jagersand, Martin},
|
||||
title = {BASNet: Boundary-Aware Salient Object Detection},
|
||||
@ -56,3 +176,10 @@ This is the official repo for our new project DIS:
|
||||
month = {June},
|
||||
year = {2019}
|
||||
}
|
||||
|
||||
@article{qin2021boundary,
|
||||
title={Boundary-aware segmentation network for mobile and web applications},
|
||||
author={Qin, Xuebin and Fan, Deng-Ping and Huang, Chenyang and Diagne, Cyril and Zhang, Zichen and Sant'Anna, Adri{\`a} Cabeza and Suarez, Albert and Jagersand, Martin and Shao, Ling},
|
||||
journal={arXiv preprint arXiv:2101.04704},
|
||||
year={2021}
|
||||
}
|
BIN
figures/.DS_Store
vendored
Normal file
BIN
figures/3d-modeling.png
Normal file
After Width: | Height: | Size: 1.1 MiB |
BIN
figures/DIS5k-dataset-v1-sailship.png
Normal file
After Width: | Height: | Size: 1.0 MiB |
BIN
figures/categories.jpeg
Normal file
After Width: | Height: | Size: 928 KiB |
BIN
figures/complexities-qual.jpeg
Normal file
After Width: | Height: | Size: 667 KiB |
BIN
figures/dis5k-v1-sailship.jpeg
Normal file
After Width: | Height: | Size: 297 KiB |
BIN
figures/hce-metric.png
Normal file
After Width: | Height: | Size: 137 KiB |
BIN
figures/is-net.png
Normal file
After Width: | Height: | Size: 912 KiB |
BIN
figures/quan-comp.png
Normal file
After Width: | Height: | Size: 887 KiB |
BIN
figures/statistics.jpeg
Normal file
After Width: | Height: | Size: 525 KiB |
BIN
figures/video-3d.gif
Normal file
After Width: | Height: | Size: 14 MiB |