Official release of our IS-Net and DIS5K

Xuebin Qin 2022-07-16 22:56:37 -07:00
parent 538d68a645
commit e6455bae38
28 changed files with 2293 additions and 18 deletions


IS-Net/basics.py Normal file

@@ -0,0 +1,74 @@
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '2'
import numpy as np
import torch
from skimage import io
def mae_torch(pred,gt):
h,w = gt.shape[0:2]
sumError = torch.sum(torch.absolute(torch.sub(pred.float(), gt.float())))
maeError = torch.divide(sumError,float(h)*float(w)*255.0+1e-4)
return maeError
def f1score_torch(pd,gt):
# print(gt.shape)
gtNum = torch.sum((gt>128).float()*1) ## number of ground truth pixels
pp = pd[gt>128]
nn = pd[gt<=128]
pp_hist =torch.histc(pp,bins=255,min=0,max=255)
nn_hist = torch.histc(nn,bins=255,min=0,max=255)
pp_hist_flip = torch.flipud(pp_hist)
nn_hist_flip = torch.flipud(nn_hist)
pp_hist_flip_cum = torch.cumsum(pp_hist_flip, dim=0)
nn_hist_flip_cum = torch.cumsum(nn_hist_flip, dim=0)
    precision = pp_hist_flip_cum / (pp_hist_flip_cum + nn_hist_flip_cum + 1e-4)
recall = (pp_hist_flip_cum)/(gtNum + 1e-4)
f1 = (1+0.3)*precision*recall/(0.3*precision+recall + 1e-4)
return torch.reshape(precision,(1,precision.shape[0])),torch.reshape(recall,(1,recall.shape[0])),torch.reshape(f1,(1,f1.shape[0]))
def f1_mae_torch(pred, gt, valid_dataset, idx, mybins, hypar):
import time
tic = time.time()
if(len(gt.shape)>2):
gt = gt[:,:,0]
pre, rec, f1 = f1score_torch(pred,gt)
mae = mae_torch(pred,gt)
# hypar["valid_out_dir"] = hypar["valid_out_dir"]+"-eval" ###
if(hypar["valid_out_dir"]!=""):
if(not os.path.exists(hypar["valid_out_dir"])):
os.mkdir(hypar["valid_out_dir"])
dataset_folder = os.path.join(hypar["valid_out_dir"],valid_dataset.dataset["data_name"][idx])
if(not os.path.exists(dataset_folder)):
os.mkdir(dataset_folder)
io.imsave(os.path.join(dataset_folder,valid_dataset.dataset["im_name"][idx]+".png"),pred.cpu().data.numpy().astype(np.uint8))
print(valid_dataset.dataset["im_name"][idx]+".png")
print("time for evaluation : ", time.time()-tic)
return pre.cpu().data.numpy(), rec.cpu().data.numpy(), f1.cpu().data.numpy(), mae.cpu().data.numpy()
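
The helpers above compute MAE and a threshold-swept precision/recall/F-beta (with beta squared = 0.3, the usual saliency convention) from tensors valued in [0, 255]. A minimal sketch of calling them directly on random stand-in data, assuming it is run from inside IS-Net/ (illustrative only, not part of the commit):

import torch
from basics import mae_torch, f1score_torch

pred = torch.randint(0, 256, (320, 320)).float()    # stand-in prediction in [0, 255]
gt = torch.randint(0, 2, (320, 320)).float() * 255  # stand-in binary ground truth

mae = mae_torch(pred, gt)                # mean absolute error, roughly in [0, 1]
pre, rec, f1 = f1score_torch(pred, gt)   # each of shape (1, 255): one value per threshold
print(mae.item(), f1.max().item())       # MAE and the max F-beta over all thresholds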

IS-Net/data_loader_cache.py Normal file

@@ -0,0 +1,380 @@
## data loader
## Acknowledgement:
## We would like to thank Dr. Ibrahim Almakky (https://scholar.google.co.uk/citations?user=T9MTcK0AAAAJ&hl=en)
## for his help in implementing the cache mechanism of our DIS dataloader.
from __future__ import print_function, division
import numpy as np
import random
from copy import deepcopy
import json
from tqdm import tqdm
from skimage import io
import os
from glob import glob
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torchvision.transforms.functional import normalize
import torch.nn.functional as F
#### --------------------- DIS dataloader cache ---------------------####
def get_im_gt_name_dict(datasets, flag='valid'):
print("------------------------------", flag, "--------------------------------")
name_im_gt_list = []
for i in range(len(datasets)):
print("--->>>", flag, " dataset ",i,"/",len(datasets)," ",datasets[i]["name"],"<<<---")
tmp_im_list, tmp_gt_list = [], []
tmp_im_list = glob(datasets[i]["im_dir"]+os.sep+'*'+datasets[i]["im_ext"])
# img_name_dict[im_dirs[i][0]] = tmp_im_list
print('-im-',datasets[i]["name"],datasets[i]["im_dir"], ': ',len(tmp_im_list))
if(datasets[i]["gt_dir"]==""):
print('-gt-', datasets[i]["name"], datasets[i]["gt_dir"], ': ', 'No Ground Truth Found')
else:
tmp_gt_list = [datasets[i]["gt_dir"]+os.sep+x.split(os.sep)[-1].split(datasets[i]["im_ext"])[0]+datasets[i]["gt_ext"] for x in tmp_im_list]
# lbl_name_dict[im_dirs[i][0]] = tmp_gt_list
print('-gt-', datasets[i]["name"],datasets[i]["gt_dir"], ': ',len(tmp_gt_list))
if flag=="train": ## combine multiple training sets into one dataset
if len(name_im_gt_list)==0:
name_im_gt_list.append({"dataset_name":datasets[i]["name"],
"im_path":tmp_im_list,
"gt_path":tmp_gt_list,
"im_ext":datasets[i]["im_ext"],
"gt_ext":datasets[i]["gt_ext"],
"cache_dir":datasets[i]["cache_dir"]})
else:
name_im_gt_list[0]["dataset_name"] = name_im_gt_list[0]["dataset_name"] + "_" + datasets[i]["name"]
name_im_gt_list[0]["im_path"] = name_im_gt_list[0]["im_path"] + tmp_im_list
name_im_gt_list[0]["gt_path"] = name_im_gt_list[0]["gt_path"] + tmp_gt_list
if datasets[i]["im_ext"]!=".jpg" or datasets[i]["gt_ext"]!=".png":
                    print("Error: Please make sure all your images and ground truth masks are in .jpg and .png format, respectively!")
exit()
name_im_gt_list[0]["im_ext"] = ".jpg"
name_im_gt_list[0]["gt_ext"] = ".png"
name_im_gt_list[0]["cache_dir"] = os.sep.join(datasets[i]["cache_dir"].split(os.sep)[0:-1])+os.sep+name_im_gt_list[0]["dataset_name"]
else: ## keep different validation or inference datasets as separate ones
name_im_gt_list.append({"dataset_name":datasets[i]["name"],
"im_path":tmp_im_list,
"gt_path":tmp_gt_list,
"im_ext":datasets[i]["im_ext"],
"gt_ext":datasets[i]["gt_ext"],
"cache_dir":datasets[i]["cache_dir"]})
return name_im_gt_list
def create_dataloaders(name_im_gt_list, cache_size=[], cache_boost=True, my_transforms=[], batch_size=1, shuffle=False):
    ## builds one (dataloader, dataset) pair per entry in name_im_gt_list;
    ## for training, get_im_gt_name_dict has already merged all sets into a single entry
gos_dataloaders = []
gos_datasets = []
# if(mode=="train"):
if(len(name_im_gt_list)==0):
return gos_dataloaders, gos_datasets
num_workers_ = 1
if(batch_size>1):
num_workers_ = 2
if(batch_size>4):
num_workers_ = 4
if(batch_size>8):
num_workers_ = 8
for i in range(0,len(name_im_gt_list)):
gos_dataset = GOSDatasetCache([name_im_gt_list[i]],
cache_size = cache_size,
cache_path = name_im_gt_list[i]["cache_dir"],
cache_boost = cache_boost,
transform = transforms.Compose(my_transforms))
gos_dataloaders.append(DataLoader(gos_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers_))
gos_datasets.append(gos_dataset)
return gos_dataloaders, gos_datasets
def im_reader(im_path):
return io.imread(im_path)
def im_preprocess(im,size):
if len(im.shape) < 3:
im = im[:, :, np.newaxis]
if im.shape[2] == 1:
im = np.repeat(im, 3, axis=2)
im_tensor = torch.tensor(im, dtype=torch.float32)
im_tensor = torch.transpose(torch.transpose(im_tensor,1,2),0,1)
if(len(size)<2):
return im_tensor, im.shape[0:2]
else:
im_tensor = torch.unsqueeze(im_tensor,0)
        im_tensor = F.interpolate(im_tensor, size, mode="bilinear")
im_tensor = torch.squeeze(im_tensor,0)
return im_tensor.type(torch.uint8), im.shape[0:2]
def gt_preprocess(gt,size):
if len(gt.shape) > 2:
gt = gt[:, :, 0]
gt_tensor = torch.unsqueeze(torch.tensor(gt, dtype=torch.uint8),0)
if(len(size)<2):
return gt_tensor.type(torch.uint8), gt.shape[0:2]
else:
        gt_tensor = torch.unsqueeze(gt_tensor.type(torch.float32), 0)
        gt_tensor = F.interpolate(gt_tensor, size, mode="bilinear")
gt_tensor = torch.squeeze(gt_tensor,0)
return gt_tensor.type(torch.uint8), gt.shape[0:2]
# return gt_tensor, gt.shape[0:2]
class GOSRandomHFlip(object):
def __init__(self,prob=0.5):
self.prob = prob
def __call__(self,sample):
imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape']
        # flip horizontally with probability self.prob
        if random.random() < self.prob:
image = torch.flip(image,dims=[2])
label = torch.flip(label,dims=[2])
return {'imidx':imidx,'image':image, 'label':label, 'shape':shape}
class GOSResize(object):
def __init__(self,size=[320,320]):
self.size = size
def __call__(self,sample):
imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape']
# import time
# start = time.time()
        image = torch.squeeze(F.interpolate(torch.unsqueeze(image,0),self.size,mode='bilinear'),dim=0)
        label = torch.squeeze(F.interpolate(torch.unsqueeze(label,0),self.size,mode='bilinear'),dim=0)
# print("time for resize: ", time.time()-start)
return {'imidx':imidx,'image':image, 'label':label, 'shape':shape}
class GOSRandomCrop(object):
def __init__(self,size=[288,288]):
self.size = size
def __call__(self,sample):
imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape']
h, w = image.shape[1:]
new_h, new_w = self.size
top = np.random.randint(0, h - new_h)
left = np.random.randint(0, w - new_w)
image = image[:,top:top+new_h,left:left+new_w]
label = label[:,top:top+new_h,left:left+new_w]
return {'imidx':imidx,'image':image, 'label':label, 'shape':shape}
class GOSNormalize(object):
def __init__(self, mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]):
self.mean = mean
self.std = std
def __call__(self,sample):
imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape']
image = normalize(image,self.mean,self.std)
return {'imidx':imidx,'image':image, 'label':label, 'shape':shape}
class GOSDatasetCache(Dataset):
def __init__(self, name_im_gt_list, cache_size=[], cache_path='./cache', cache_file_name='dataset.json', cache_boost=False, transform=None):
self.cache_size = cache_size
self.cache_path = cache_path
self.cache_file_name = cache_file_name
self.cache_boost_name = ""
self.cache_boost = cache_boost
# self.ims_npy = None
# self.gts_npy = None
## cache all the images and ground truth into a single pytorch tensor
self.ims_pt = None
self.gts_pt = None
## we will cache the npy as well regardless of the cache_boost
# if(self.cache_boost):
self.cache_boost_name = cache_file_name.split('.json')[0]
self.transform = transform
self.dataset = {}
## combine different datasets into one
dataset_names = []
dt_name_list = [] # dataset name per image
im_name_list = [] # image name
im_path_list = [] # im path
gt_path_list = [] # gt path
im_ext_list = [] # im ext
gt_ext_list = [] # gt ext
for i in range(0,len(name_im_gt_list)):
dataset_names.append(name_im_gt_list[i]["dataset_name"])
# dataset name repeated based on the number of images in this dataset
dt_name_list.extend([name_im_gt_list[i]["dataset_name"] for x in name_im_gt_list[i]["im_path"]])
im_name_list.extend([x.split(os.sep)[-1].split(name_im_gt_list[i]["im_ext"])[0] for x in name_im_gt_list[i]["im_path"]])
im_path_list.extend(name_im_gt_list[i]["im_path"])
gt_path_list.extend(name_im_gt_list[i]["gt_path"])
im_ext_list.extend([name_im_gt_list[i]["im_ext"] for x in name_im_gt_list[i]["im_path"]])
gt_ext_list.extend([name_im_gt_list[i]["gt_ext"] for x in name_im_gt_list[i]["gt_path"]])
self.dataset["data_name"] = dt_name_list
self.dataset["im_name"] = im_name_list
self.dataset["im_path"] = im_path_list
self.dataset["ori_im_path"] = deepcopy(im_path_list)
self.dataset["gt_path"] = gt_path_list
self.dataset["ori_gt_path"] = deepcopy(gt_path_list)
self.dataset["im_shp"] = []
self.dataset["gt_shp"] = []
self.dataset["im_ext"] = im_ext_list
self.dataset["gt_ext"] = gt_ext_list
self.dataset["ims_pt_dir"] = ""
self.dataset["gts_pt_dir"] = ""
self.dataset = self.manage_cache(dataset_names)
def manage_cache(self,dataset_names):
if not os.path.exists(self.cache_path): # create the folder for cache
os.mkdir(self.cache_path)
cache_folder = os.path.join(self.cache_path, "_".join(dataset_names)+"_"+"x".join([str(x) for x in self.cache_size]))
if not os.path.exists(cache_folder): # check if the cache files are there, if not then cache
return self.cache(cache_folder)
return self.load_cache(cache_folder)
def cache(self,cache_folder):
os.mkdir(cache_folder)
cached_dataset = deepcopy(self.dataset)
# ims_list = []
# gts_list = []
ims_pt_list = []
gts_pt_list = []
for i, im_path in tqdm(enumerate(self.dataset["im_path"]), total=len(self.dataset["im_path"])):
im_id = cached_dataset["im_name"][i]
im = im_reader(im_path)
im, im_shp = im_preprocess(im,self.cache_size)
im_cache_file = os.path.join(cache_folder,self.dataset["data_name"][i]+"_"+im_id + "_im.pt")
torch.save(im,im_cache_file)
cached_dataset["im_path"][i] = im_cache_file
if(self.cache_boost):
ims_pt_list.append(torch.unsqueeze(im,0))
# ims_list.append(im.cpu().data.numpy().astype(np.uint8))
gt = im_reader(self.dataset["gt_path"][i])
gt, gt_shp = gt_preprocess(gt,self.cache_size)
gt_cache_file = os.path.join(cache_folder,self.dataset["data_name"][i]+"_"+im_id + "_gt.pt")
torch.save(gt,gt_cache_file)
cached_dataset["gt_path"][i] = gt_cache_file
if(self.cache_boost):
gts_pt_list.append(torch.unsqueeze(gt,0))
# gts_list.append(gt.cpu().data.numpy().astype(np.uint8))
# im_shp_cache_file = os.path.join(cache_folder,im_id + "_im_shp.pt")
# torch.save(gt_shp, shp_cache_file)
cached_dataset["im_shp"].append(im_shp)
# self.dataset["im_shp"].append(im_shp)
# shp_cache_file = os.path.join(cache_folder,im_id + "_gt_shp.pt")
# torch.save(gt_shp, shp_cache_file)
cached_dataset["gt_shp"].append(gt_shp)
# self.dataset["gt_shp"].append(gt_shp)
if(self.cache_boost):
cached_dataset["ims_pt_dir"] = os.path.join(cache_folder, self.cache_boost_name+'_ims.pt')
cached_dataset["gts_pt_dir"] = os.path.join(cache_folder, self.cache_boost_name+'_gts.pt')
self.ims_pt = torch.cat(ims_pt_list,dim=0)
self.gts_pt = torch.cat(gts_pt_list,dim=0)
torch.save(torch.cat(ims_pt_list,dim=0),cached_dataset["ims_pt_dir"])
torch.save(torch.cat(gts_pt_list,dim=0),cached_dataset["gts_pt_dir"])
        try:
            with open(os.path.join(cache_folder, self.cache_file_name), "w") as json_file:
                json.dump(cached_dataset, json_file)
        except Exception as e:
            raise RuntimeError("Cannot create JSON cache file") from e
return cached_dataset
    def load_cache(self, cache_folder):
        with open(os.path.join(cache_folder, self.cache_file_name), "r") as json_file:
            dataset = json.load(json_file)
        ## if cache_boost is true, load the whole image and ground-truth tensor stacks into RAM;
        ## otherwise the per-item .pt files are loaded lazily in __getitem__
if(self.cache_boost):
# self.ims_npy = np.load(dataset["ims_npy_dir"])
# self.gts_npy = np.load(dataset["gts_npy_dir"])
self.ims_pt = torch.load(dataset["ims_pt_dir"], map_location='cpu')
self.gts_pt = torch.load(dataset["gts_pt_dir"], map_location='cpu')
return dataset
def __len__(self):
return len(self.dataset["im_path"])
def __getitem__(self, idx):
im = None
gt = None
if(self.cache_boost and self.ims_pt is not None):
# start = time.time()
im = self.ims_pt[idx]#.type(torch.float32)
gt = self.gts_pt[idx]#.type(torch.float32)
# print(idx, 'time for pt loading: ', time.time()-start)
else:
# import time
# start = time.time()
# print("tensor***")
im_pt_path = os.path.join(self.cache_path,os.sep.join(self.dataset["im_path"][idx].split(os.sep)[-2:]))
im = torch.load(im_pt_path)#(self.dataset["im_path"][idx])
gt_pt_path = os.path.join(self.cache_path,os.sep.join(self.dataset["gt_path"][idx].split(os.sep)[-2:]))
gt = torch.load(gt_pt_path)#(self.dataset["gt_path"][idx])
# print(idx,'time for tensor loading: ', time.time()-start)
im_shp = self.dataset["im_shp"][idx]
# print("time for loading im and gt: ", time.time()-start)
# start_time = time.time()
im = torch.divide(im,255.0)
gt = torch.divide(gt,255.0)
# print(idx, 'time for normalize torch divide: ', time.time()-start_time)
sample = {
"imidx": torch.from_numpy(np.array(idx)),
"image": im,
"label": gt,
"shape": torch.from_numpy(np.array(im_shp)),
}
if self.transform:
sample = self.transform(sample)
return sample
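
For reference, a sketch of how this cache dataloader is typically wired up, assuming it is run from inside IS-Net/. The directory paths below are placeholders, but the dict keys match exactly what get_im_gt_name_dict and GOSDatasetCache read, and the normalization constants are the ones used elsewhere in this commit:

from data_loader_cache import get_im_gt_name_dict, create_dataloaders, GOSNormalize

demo_dataset = {"name": "DIS5K-TR",
                "im_dir": "../DIS5K/DIS-TR/im",        # placeholder path
                "gt_dir": "../DIS5K/DIS-TR/gt",        # placeholder path
                "im_ext": ".jpg",
                "gt_ext": ".png",
                "cache_dir": "../DIS5K-Cache/DIS-TR"}  # placeholder path

name_im_gt_list = get_im_gt_name_dict([demo_dataset], flag="train")
dataloaders, datasets = create_dataloaders(name_im_gt_list,
                                           cache_size=[1024, 1024],
                                           cache_boost=False,
                                           my_transforms=[GOSNormalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0])],
                                           batch_size=2,
                                           shuffle=True)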

IS-Net/hce_metric_main.py Normal file

@@ -0,0 +1,188 @@
## hce_metric.py
import numpy as np
from skimage import io
import matplotlib.pyplot as plt
import cv2 as cv
from skimage.morphology import skeletonize
from skimage.morphology import erosion, dilation, disk
from skimage.measure import label
import os
import sys
from tqdm import tqdm
from glob import glob
import pickle as pkl
def filter_bdy_cond(bdy_, mask, cond):
cond = cv.dilate(cond.astype(np.uint8),disk(1))
labels = label(mask) # find the connected regions
lbls = np.unique(labels) # the indices of the connected regions
    indep = np.ones(lbls.shape[0]) # independence flag of each connected region
    indep[0] = 0 # 0 indicates the background region
boundaries = []
h,w = cond.shape[0:2]
ind_map = np.zeros((h,w))
indep_cnt = 0
for i in range(0,len(bdy_)):
tmp_bdies = []
tmp_bdy = []
for j in range(0,bdy_[i].shape[0]):
r, c = bdy_[i][j,0,1],bdy_[i][j,0,0]
if(np.sum(cond[r,c])==0 or ind_map[r,c]!=0):
if(len(tmp_bdy)>0):
tmp_bdies.append(tmp_bdy)
tmp_bdy = []
continue
tmp_bdy.append([c,r])
ind_map[r,c] = ind_map[r,c] + 1
indep[labels[r,c]] = 0 # indicates part of the boundary of this region needs human correction
if(len(tmp_bdy)>0):
tmp_bdies.append(tmp_bdy)
# check if the first and the last boundaries are connected
# if yes, invert the first boundary and attach it after the last boundary
if(len(tmp_bdies)>1):
first_x, first_y = tmp_bdies[0][0]
last_x, last_y = tmp_bdies[-1][-1]
if((abs(first_x-last_x)==1 and first_y==last_y) or
(first_x==last_x and abs(first_y-last_y)==1) or
(abs(first_x-last_x)==1 and abs(first_y-last_y)==1)
):
tmp_bdies[-1].extend(tmp_bdies[0][::-1])
del tmp_bdies[0]
for k in range(0,len(tmp_bdies)):
tmp_bdies[k] = np.array(tmp_bdies[k])[:,np.newaxis,:]
if(len(tmp_bdies)>0):
boundaries.extend(tmp_bdies)
return boundaries, np.sum(indep)
# this function approximates each boundary with the Ramer-Douglas-Peucker (RDP) algorithm
# https://en.wikipedia.org/wiki/Ramer%E2%80%93Douglas%E2%80%93Peucker_algorithm
def approximate_RDP(boundaries,epsilon=1.0):
boundaries_ = []
boundaries_len_ = []
pixel_cnt_ = 0
# polygon approximate of each boundary
for i in range(0,len(boundaries)):
boundaries_.append(cv.approxPolyDP(boundaries[i],epsilon,False))
    # count the number of control points of each boundary and the total across all boundaries
for i in range(0,len(boundaries_)):
boundaries_len_.append(len(boundaries_[i]))
pixel_cnt_ = pixel_cnt_ + len(boundaries_[i])
return boundaries_, boundaries_len_, pixel_cnt_
def relax_HCE(gt, rs, gt_ske, relax=5, epsilon=2.0):
# print("max(gt_ske): ", np.amax(gt_ske))
# gt_ske = gt_ske>128
# print("max(gt_ske): ", np.amax(gt_ske))
# Binarize gt
if(len(gt.shape)>2):
gt = gt[:,:,0]
    epsilon_gt = 128 # fixed binarization threshold; alternatively (np.amin(gt)+np.amax(gt))/2.0
gt = (gt>epsilon_gt).astype(np.uint8)
# Binarize rs
if(len(rs.shape)>2):
rs = rs[:,:,0]
    epsilon_rs = 128 # fixed binarization threshold; alternatively (np.amin(rs)+np.amax(rs))/2.0
rs = (rs>epsilon_rs).astype(np.uint8)
Union = np.logical_or(gt,rs)
TP = np.logical_and(gt,rs)
FP = rs - TP
FN = gt - TP
# relax the Union of gt and rs
Union_erode = Union.copy()
Union_erode = cv.erode(Union_erode.astype(np.uint8),disk(1),iterations=relax)
# --- get the relaxed False Positive regions for computing the human efforts in correcting them ---
FP_ = np.logical_and(FP,Union_erode) # get the relaxed FP
for i in range(0,relax):
FP_ = cv.dilate(FP_.astype(np.uint8),disk(1))
FP_ = np.logical_and(FP_, 1-np.logical_or(TP,FN))
FP_ = np.logical_and(FP, FP_)
# --- get the relaxed False Negative regions for computing the human efforts in correcting them ---
FN_ = np.logical_and(FN,Union_erode) # preserve the structural components of FN
## recover the FN, where pixels are not close to the TP borders
for i in range(0,relax):
FN_ = cv.dilate(FN_.astype(np.uint8),disk(1))
FN_ = np.logical_and(FN_,1-np.logical_or(TP,FP))
FN_ = np.logical_and(FN,FN_)
FN_ = np.logical_or(FN_, np.logical_xor(gt_ske,np.logical_and(TP,gt_ske))) # preserve the structural components of FN
## 2. =============Find exact polygon control points and independent regions==============
## find contours from FP_
ctrs_FP, hier_FP = cv.findContours(FP_.astype(np.uint8), cv.RETR_TREE, cv.CHAIN_APPROX_NONE)
## find control points and independent regions for human correction
bdies_FP, indep_cnt_FP = filter_bdy_cond(ctrs_FP, FP_, np.logical_or(TP,FN_))
## find contours from FN_
ctrs_FN, hier_FN = cv.findContours(FN_.astype(np.uint8), cv.RETR_TREE, cv.CHAIN_APPROX_NONE)
## find control points and independent regions for human correction
bdies_FN, indep_cnt_FN = filter_bdy_cond(ctrs_FN, FN_, 1-np.logical_or(np.logical_or(TP,FP_),FN_))
poly_FP, poly_FP_len, poly_FP_point_cnt = approximate_RDP(bdies_FP,epsilon=epsilon)
poly_FN, poly_FN_len, poly_FN_point_cnt = approximate_RDP(bdies_FN,epsilon=epsilon)
return poly_FP_point_cnt, indep_cnt_FP, poly_FN_point_cnt, indep_cnt_FN
def compute_hce(pred_root,gt_root,gt_ske_root):
gt_name_list = glob(pred_root+'/*.png')
gt_name_list = sorted([x.split('/')[-1] for x in gt_name_list])
hces = []
for gt_name in tqdm(gt_name_list, total=len(gt_name_list)):
gt_path = os.path.join(gt_root, gt_name)
pred_path = os.path.join(pred_root, gt_name)
gt = cv.imread(gt_path, cv.IMREAD_GRAYSCALE)
pred = cv.imread(pred_path, cv.IMREAD_GRAYSCALE)
ske_path = os.path.join(gt_ske_root,gt_name)
if os.path.exists(ske_path):
ske = cv.imread(ske_path,cv.IMREAD_GRAYSCALE)
ske = ske>128
else:
ske = skeletonize(gt>128)
FP_points, FP_indep, FN_points, FN_indep = relax_HCE(gt, pred,ske)
print(gt_path.split('/')[-1],FP_points, FP_indep, FN_points, FN_indep)
hces.append([FP_points, FP_indep, FN_points, FN_indep, FP_points+FP_indep+FN_points+FN_indep])
hce_metric ={'names': gt_name_list,
'hces': hces}
file_metric = open(pred_root+'/hce_metric.pkl','wb')
pkl.dump(hce_metric,file_metric)
# file_metrics.write(cmn_metrics)
file_metric.close()
return np.mean(np.array(hces)[:,-1])
def main():
gt_root = "../DIS5K/DIS-VD/gt"
gt_ske_root = ""
pred_root = "../Results/isnet(ours)/DIS-VD"
print("The average HCE metric: ", compute_hce(pred_root,gt_root,gt_ske_root))
if __name__ == '__main__':
main()
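
The per-image pipeline inside compute_hce can also be run standalone; a minimal sketch for one prediction/ground-truth pair, assuming it is run from inside IS-Net/ (file paths are placeholders):

import cv2 as cv
from skimage.morphology import skeletonize
from hce_metric_main import relax_HCE

gt = cv.imread("example_gt.png", cv.IMREAD_GRAYSCALE)      # placeholder path
pred = cv.imread("example_pred.png", cv.IMREAD_GRAYSCALE)  # placeholder path
ske = skeletonize(gt > 128)  # on-the-fly skeleton, the same fallback compute_hce uses

FP_points, FP_indep, FN_points, FN_indep = relax_HCE(gt, pred, ske)
print("HCE:", FP_points + FP_indep + FN_points + FN_indep)  # the total compute_hce accumulates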

IS-Net/models/__init__.py Normal file

@@ -0,0 +1 @@
from models.isnet import ISNetGTEncoder, ISNetDIS

IS-Net/models/isnet.py Normal file

@@ -0,0 +1,614 @@
import torch
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F
bce_loss = nn.BCELoss(reduction='mean')
def muti_loss_fusion(preds, target):
loss0 = 0.0
loss = 0.0
for i in range(0,len(preds)):
# print("i: ", i, preds[i].shape)
if(preds[i].shape[2]!=target.shape[2] or preds[i].shape[3]!=target.shape[3]):
# tmp_target = _upsample_like(target,preds[i])
tmp_target = F.interpolate(target, size=preds[i].size()[2:], mode='bilinear', align_corners=True)
loss = loss + bce_loss(preds[i],tmp_target)
else:
loss = loss + bce_loss(preds[i],target)
if(i==0):
loss0 = loss
return loss0, loss
fea_loss = nn.MSELoss(reduction='mean')
kl_loss = nn.KLDivLoss(reduction='mean')
l1_loss = nn.L1Loss(reduction='mean')
smooth_l1_loss = nn.SmoothL1Loss(reduction='mean')
def muti_loss_fusion_kl(preds, target, dfs, fs, mode='MSE'):
loss0 = 0.0
loss = 0.0
for i in range(0,len(preds)):
# print("i: ", i, preds[i].shape)
if(preds[i].shape[2]!=target.shape[2] or preds[i].shape[3]!=target.shape[3]):
# tmp_target = _upsample_like(target,preds[i])
tmp_target = F.interpolate(target, size=preds[i].size()[2:], mode='bilinear', align_corners=True)
loss = loss + bce_loss(preds[i],tmp_target)
else:
loss = loss + bce_loss(preds[i],target)
if(i==0):
loss0 = loss
for i in range(0,len(dfs)):
if(mode=='MSE'):
loss = loss + fea_loss(dfs[i],fs[i]) ### add the mse loss of features as additional constraints
# print("fea_loss: ", fea_loss(dfs[i],fs[i]).item())
elif(mode=='KL'):
loss = loss + kl_loss(F.log_softmax(dfs[i],dim=1),F.softmax(fs[i],dim=1))
# print("kl_loss: ", kl_loss(F.log_softmax(dfs[i],dim=1),F.softmax(fs[i],dim=1)).item())
elif(mode=='MAE'):
loss = loss + l1_loss(dfs[i],fs[i])
# print("ls_loss: ", l1_loss(dfs[i],fs[i]))
elif(mode=='SmoothL1'):
loss = loss + smooth_l1_loss(dfs[i],fs[i])
# print("SmoothL1: ", smooth_l1_loss(dfs[i],fs[i]).item())
return loss0, loss
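
# Note on the feature terms above: in the training script, when intermediate
# supervision ("interm_sup") is enabled, dfs are the decoder features of ISNetDIS
# and fs are the matching features of a frozen, pre-trained ISNetGTEncoder, so
# this loss pulls the segmentation features toward the ground-truth encodings.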
class REBNCONV(nn.Module):
def __init__(self,in_ch=3,out_ch=3,dirate=1,stride=1):
super(REBNCONV,self).__init__()
self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate,stride=stride)
self.bn_s1 = nn.BatchNorm2d(out_ch)
self.relu_s1 = nn.ReLU(inplace=True)
def forward(self,x):
hx = x
xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
return xout
## upsample tensor 'src' to the same spatial size as tensor 'tar'
def _upsample_like(src,tar):
    src = F.interpolate(src, size=tar.shape[2:], mode='bilinear')
return src
### RSU-7 ###
class RSU7(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
super(RSU7,self).__init__()
self.in_ch = in_ch
self.mid_ch = mid_ch
self.out_ch = out_ch
        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)
self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
def forward(self,x):
b, c, h, w = x.shape
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx = self.pool3(hx3)
hx4 = self.rebnconv4(hx)
hx = self.pool4(hx4)
hx5 = self.rebnconv5(hx)
hx = self.pool5(hx5)
hx6 = self.rebnconv6(hx)
hx7 = self.rebnconv7(hx6)
hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))
hx6dup = _upsample_like(hx6d,hx5)
hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1))
hx5dup = _upsample_like(hx5d,hx4)
hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
hx4dup = _upsample_like(hx4d,hx3)
hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
hx3dup = _upsample_like(hx3d,hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
hx2dup = _upsample_like(hx2d,hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
return hx1d + hxin
### RSU-6 ###
class RSU6(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU6,self).__init__()
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)
self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
def forward(self,x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx = self.pool3(hx3)
hx4 = self.rebnconv4(hx)
hx = self.pool4(hx4)
hx5 = self.rebnconv5(hx)
hx6 = self.rebnconv6(hx5)
hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1))
hx5dup = _upsample_like(hx5d,hx4)
hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
hx4dup = _upsample_like(hx4d,hx3)
hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
hx3dup = _upsample_like(hx3d,hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
hx2dup = _upsample_like(hx2d,hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
return hx1d + hxin
### RSU-5 ###
class RSU5(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU5,self).__init__()
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)
self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
def forward(self,x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx = self.pool3(hx3)
hx4 = self.rebnconv4(hx)
hx5 = self.rebnconv5(hx4)
hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))
hx4dup = _upsample_like(hx4d,hx3)
hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
hx3dup = _upsample_like(hx3d,hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
hx2dup = _upsample_like(hx2d,hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
return hx1d + hxin
### RSU-4 ###
class RSU4(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU4,self).__init__()
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
def forward(self,x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx4 = self.rebnconv4(hx3)
hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
hx3dup = _upsample_like(hx3d,hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
hx2dup = _upsample_like(hx2d,hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
return hx1d + hxin
### RSU-4F ###
class RSU4F(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU4F,self).__init__()
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
def forward(self,x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx2 = self.rebnconv2(hx1)
hx3 = self.rebnconv3(hx2)
hx4 = self.rebnconv4(hx3)
hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))
return hx1d + hxin
class myrebnconv(nn.Module):
def __init__(self, in_ch=3,
out_ch=1,
kernel_size=3,
stride=1,
padding=1,
dilation=1,
groups=1):
super(myrebnconv,self).__init__()
self.conv = nn.Conv2d(in_ch,
out_ch,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups)
self.bn = nn.BatchNorm2d(out_ch)
self.rl = nn.ReLU(inplace=True)
def forward(self,x):
return self.rl(self.bn(self.conv(x)))
class ISNetGTEncoder(nn.Module):
def __init__(self,in_ch=1,out_ch=1):
super(ISNetGTEncoder,self).__init__()
self.conv_in = myrebnconv(in_ch,16,3,stride=2,padding=1) # nn.Conv2d(in_ch,64,3,stride=2,padding=1)
self.stage1 = RSU7(16,16,64)
self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.stage2 = RSU6(64,16,64)
self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.stage3 = RSU5(64,32,128)
self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.stage4 = RSU4(128,32,256)
self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.stage5 = RSU4F(256,64,512)
self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.stage6 = RSU4F(512,64,512)
self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
self.side6 = nn.Conv2d(512,out_ch,3,padding=1)
    def compute_loss_max(self, preds, targets, fs):
        # NOTE: muti_loss_fusion_max is not defined in this file; this method appears to be unused.
        return muti_loss_fusion_max(preds, targets, fs)
def compute_loss(self, preds, targets):
return muti_loss_fusion(preds,targets)
def forward(self,x):
hx = x
hxin = self.conv_in(hx)
# hx = self.pool_in(hxin)
#stage 1
hx1 = self.stage1(hxin)
hx = self.pool12(hx1)
#stage 2
hx2 = self.stage2(hx)
hx = self.pool23(hx2)
#stage 3
hx3 = self.stage3(hx)
hx = self.pool34(hx3)
#stage 4
hx4 = self.stage4(hx)
hx = self.pool45(hx4)
#stage 5
hx5 = self.stage5(hx)
hx = self.pool56(hx5)
#stage 6
hx6 = self.stage6(hx)
#side output
d1 = self.side1(hx1)
d1 = _upsample_like(d1,x)
d2 = self.side2(hx2)
d2 = _upsample_like(d2,x)
d3 = self.side3(hx3)
d3 = _upsample_like(d3,x)
d4 = self.side4(hx4)
d4 = _upsample_like(d4,x)
d5 = self.side5(hx5)
d5 = _upsample_like(d5,x)
d6 = self.side6(hx6)
d6 = _upsample_like(d6,x)
# d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
        return [torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)], [hx1,hx2,hx3,hx4,hx5,hx6]
class ISNetDIS(nn.Module):
def __init__(self,in_ch=3,out_ch=1):
super(ISNetDIS,self).__init__()
self.conv_in = nn.Conv2d(in_ch,64,3,stride=2,padding=1)
self.pool_in = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.stage1 = RSU7(64,32,64)
self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.stage2 = RSU6(64,32,128)
self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.stage3 = RSU5(128,64,256)
self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.stage4 = RSU4(256,128,512)
self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.stage5 = RSU4F(512,256,512)
self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.stage6 = RSU4F(512,256,512)
# decoder
self.stage5d = RSU4F(1024,256,512)
self.stage4d = RSU4(1024,128,256)
self.stage3d = RSU5(512,64,128)
self.stage2d = RSU6(256,32,64)
self.stage1d = RSU7(128,16,64)
self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
self.side6 = nn.Conv2d(512,out_ch,3,padding=1)
# self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
def compute_loss_kl(self, preds, targets, dfs, fs, mode='MSE'):
# return muti_loss_fusion(preds,targets)
return muti_loss_fusion_kl(preds, targets, dfs, fs, mode=mode)
def compute_loss(self, preds, targets):
# return muti_loss_fusion(preds,targets)
return muti_loss_fusion(preds, targets)
def forward(self,x):
hx = x
hxin = self.conv_in(hx)
        hx = self.pool_in(hxin) # note: this pooled tensor is overwritten below; stage1 consumes hxin directly
#stage 1
hx1 = self.stage1(hxin)
hx = self.pool12(hx1)
#stage 2
hx2 = self.stage2(hx)
hx = self.pool23(hx2)
#stage 3
hx3 = self.stage3(hx)
hx = self.pool34(hx3)
#stage 4
hx4 = self.stage4(hx)
hx = self.pool45(hx4)
#stage 5
hx5 = self.stage5(hx)
hx = self.pool56(hx5)
#stage 6
hx6 = self.stage6(hx)
hx6up = _upsample_like(hx6,hx5)
#-------------------- decoder --------------------
hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
hx5dup = _upsample_like(hx5d,hx4)
hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
hx4dup = _upsample_like(hx4d,hx3)
hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
hx3dup = _upsample_like(hx3d,hx2)
hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
hx2dup = _upsample_like(hx2d,hx1)
hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
#side output
d1 = self.side1(hx1d)
d1 = _upsample_like(d1,x)
d2 = self.side2(hx2d)
d2 = _upsample_like(d2,x)
d3 = self.side3(hx3d)
d3 = _upsample_like(d3,x)
d4 = self.side4(hx4d)
d4 = _upsample_like(d4,x)
d5 = self.side5(hx5d)
d5 = _upsample_like(d5,x)
d6 = self.side6(hx6)
d6 = _upsample_like(d6,x)
# d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
        return [torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)],[hx1d,hx2d,hx3d,hx4d,hx5d,hx6]
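
A quick shape-level smoke test for the two networks, assuming it is run from inside IS-Net/ (illustrative sketch; no pretrained weights involved, and 1024x1024 is just a stand-in input size):

import torch
from models.isnet import ISNetDIS, ISNetGTEncoder

net = ISNetDIS(in_ch=3, out_ch=1)
x = torch.randn(1, 3, 1024, 1024)
with torch.no_grad():
    side_outputs, decoder_feats = net(x)    # six sigmoid side outputs + six feature maps
print(side_outputs[0].shape)                # d1 is upsampled back to the input resolution

gt_net = ISNetGTEncoder(in_ch=1, out_ch=1)  # the GT encoder consumes 1-channel masks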

IS-Net/pytorch18.yml Normal file

@@ -0,0 +1,92 @@
name: pytorch18
channels:
- conda-forge
- anaconda
- pytorch
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=4.5=1_gnu
- blas=1.0=mkl
- brotli=1.0.9=he6710b0_2
- bzip2=1.0.8=h7b6447c_0
- ca-certificates=2022.2.1=h06a4308_0
- certifi=2021.10.8=py37h06a4308_2
- cloudpickle=2.0.0=pyhd3eb1b0_0
- colorama=0.4.4=pyhd3eb1b0_0
- cudatoolkit=10.2.89=hfd86e86_1
- cycler=0.11.0=pyhd3eb1b0_0
- cytoolz=0.11.0=py37h7b6447c_0
- dask-core=2021.10.0=pyhd3eb1b0_0
- ffmpeg=4.3=hf484d3e_0
- fonttools=4.25.0=pyhd3eb1b0_0
- freetype=2.11.0=h70c0345_0
- fsspec=2022.2.0=pyhd3eb1b0_0
- gmp=6.2.1=h2531618_2
- gnutls=3.6.15=he1e5248_0
- imageio=2.9.0=pyhd3eb1b0_0
- intel-openmp=2021.4.0=h06a4308_3561
- jpeg=9b=h024ee3a_2
- kiwisolver=1.3.2=py37h295c915_0
- lame=3.100=h7b6447c_0
- lcms2=2.12=h3be6417_0
- ld_impl_linux-64=2.35.1=h7274673_9
- libffi=3.3=he6710b0_2
- libgcc-ng=9.3.0=h5101ec6_17
- libgfortran-ng=7.5.0=ha8ba4b0_17
- libgfortran4=7.5.0=ha8ba4b0_17
- libgomp=9.3.0=h5101ec6_17
- libiconv=1.15=h63c8f33_5
- libidn2=2.3.2=h7f8727e_0
- libpng=1.6.37=hbc83047_0
- libstdcxx-ng=9.3.0=hd4cf53a_17
- libtasn1=4.16.0=h27cfd23_0
- libtiff=4.2.0=h85742a9_0
- libunistring=0.9.10=h27cfd23_0
- libuv=1.40.0=h7b6447c_0
- libwebp-base=1.2.2=h7f8727e_0
- locket=0.2.1=py37h06a4308_2
- lz4-c=1.9.3=h295c915_1
- matplotlib-base=3.5.1=py37ha18d171_1
- mkl=2021.4.0=h06a4308_640
- mkl-service=2.4.0=py37h7f8727e_0
- mkl_fft=1.3.1=py37hd3c417c_0
- mkl_random=1.2.2=py37h51133e4_0
- munkres=1.1.4=py_0
- ncurses=6.3=h7f8727e_2
- nettle=3.7.3=hbbd107a_1
- networkx=2.6.3=pyhd3eb1b0_0
- ninja=1.10.2=py37hd09550d_3
- numpy=1.21.2=py37h20f2e39_0
- numpy-base=1.21.2=py37h79a1101_0
- olefile=0.46=py37_0
- openh264=2.1.1=h4ff587b_0
- openssl=1.1.1n=h7f8727e_0
- packaging=21.3=pyhd3eb1b0_0
- partd=1.2.0=pyhd3eb1b0_1
- pillow=8.0.0=py37h9a89aac_0
- pip=21.2.2=py37h06a4308_0
- pyparsing=3.0.4=pyhd3eb1b0_0
- python=3.7.11=h12debd9_0
- python-dateutil=2.8.2=pyhd3eb1b0_0
- pytorch=1.8.0=py3.7_cuda10.2_cudnn7.6.5_0
- pywavelets=1.1.1=py37h7b6447c_2
- pyyaml=6.0=py37h7f8727e_1
- readline=8.1.2=h7f8727e_1
- scikit-image=0.15.0=py37hb3f55d8_2
- scipy=1.7.3=py37hc147768_0
- setuptools=58.0.4=py37h06a4308_0
- six=1.16.0=pyhd3eb1b0_1
- sqlite=3.38.0=hc218d9a_0
- tk=8.6.11=h1ccaba5_0
- toolz=0.11.2=pyhd3eb1b0_0
- torchaudio=0.8.0=py37
- torchvision=0.9.0=py37_cu102
- tqdm=4.63.0=pyhd8ed1ab_0
- typing_extensions=3.10.0.2=pyh06a4308_0
- wheel=0.37.1=pyhd3eb1b0_0
- xz=5.2.5=h7b6447c_0
- yaml=0.2.5=h7b6447c_0
- zlib=1.2.11=h7f8727e_4
- zstd=1.4.9=haebb681_0
prefix: /home/solar/anaconda3/envs/pytorch18
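
To recreate this environment, the standard conda workflow should work: conda env create -f IS-Net/pytorch18.yml, then conda activate pytorch18 (the environment name comes from the name: field above). The prefix: line is machine-specific export residue and can safely be deleted before use.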

IS-Net/requirements.txt Normal file

@@ -0,0 +1,87 @@
# This file may be used to create an environment using:
# $ conda create --name <env> --file <this file>
# platform: linux-64
_libgcc_mutex=0.1=main
_openmp_mutex=4.5=1_gnu
blas=1.0=mkl
brotli=1.0.9=he6710b0_2
bzip2=1.0.8=h7b6447c_0
ca-certificates=2022.2.1=h06a4308_0
certifi=2021.10.8=py37h06a4308_2
cloudpickle=2.0.0=pyhd3eb1b0_0
colorama=0.4.4=pyhd3eb1b0_0
cudatoolkit=10.2.89=hfd86e86_1
cycler=0.11.0=pyhd3eb1b0_0
cytoolz=0.11.0=py37h7b6447c_0
dask-core=2021.10.0=pyhd3eb1b0_0
ffmpeg=4.3=hf484d3e_0
fonttools=4.25.0=pyhd3eb1b0_0
freetype=2.11.0=h70c0345_0
fsspec=2022.2.0=pyhd3eb1b0_0
gmp=6.2.1=h2531618_2
gnutls=3.6.15=he1e5248_0
imageio=2.9.0=pyhd3eb1b0_0
intel-openmp=2021.4.0=h06a4308_3561
jpeg=9b=h024ee3a_2
kiwisolver=1.3.2=py37h295c915_0
lame=3.100=h7b6447c_0
lcms2=2.12=h3be6417_0
ld_impl_linux-64=2.35.1=h7274673_9
libffi=3.3=he6710b0_2
libgcc-ng=9.3.0=h5101ec6_17
libgfortran-ng=7.5.0=ha8ba4b0_17
libgfortran4=7.5.0=ha8ba4b0_17
libgomp=9.3.0=h5101ec6_17
libiconv=1.15=h63c8f33_5
libidn2=2.3.2=h7f8727e_0
libpng=1.6.37=hbc83047_0
libstdcxx-ng=9.3.0=hd4cf53a_17
libtasn1=4.16.0=h27cfd23_0
libtiff=4.2.0=h85742a9_0
libunistring=0.9.10=h27cfd23_0
libuv=1.40.0=h7b6447c_0
libwebp-base=1.2.2=h7f8727e_0
locket=0.2.1=py37h06a4308_2
lz4-c=1.9.3=h295c915_1
matplotlib-base=3.5.1=py37ha18d171_1
mkl=2021.4.0=h06a4308_640
mkl-service=2.4.0=py37h7f8727e_0
mkl_fft=1.3.1=py37hd3c417c_0
mkl_random=1.2.2=py37h51133e4_0
munkres=1.1.4=py_0
ncurses=6.3=h7f8727e_2
nettle=3.7.3=hbbd107a_1
networkx=2.6.3=pyhd3eb1b0_0
ninja=1.10.2=py37hd09550d_3
numpy=1.21.2=py37h20f2e39_0
numpy-base=1.21.2=py37h79a1101_0
olefile=0.46=py37_0
openh264=2.1.1=h4ff587b_0
openssl=1.1.1n=h7f8727e_0
packaging=21.3=pyhd3eb1b0_0
partd=1.2.0=pyhd3eb1b0_1
pillow=8.0.0=py37h9a89aac_0
pip=21.2.2=py37h06a4308_0
pyparsing=3.0.4=pyhd3eb1b0_0
python=3.7.11=h12debd9_0
python-dateutil=2.8.2=pyhd3eb1b0_0
pytorch=1.8.0=py3.7_cuda10.2_cudnn7.6.5_0
pywavelets=1.1.1=py37h7b6447c_2
pyyaml=6.0=py37h7f8727e_1
readline=8.1.2=h7f8727e_1
scikit-image=0.15.0=py37hb3f55d8_2
scipy=1.7.3=py37hc147768_0
setuptools=58.0.4=py37h06a4308_0
six=1.16.0=pyhd3eb1b0_1
sqlite=3.38.0=hc218d9a_0
tk=8.6.11=h1ccaba5_0
toolz=0.11.2=pyhd3eb1b0_0
torchaudio=0.8.0=py37
torchvision=0.9.0=py37_cu102
tqdm=4.63.0=pyhd8ed1ab_0
typing_extensions=3.10.0.2=pyh06a4308_0
wheel=0.37.1=pyhd3eb1b0_0
xz=5.2.5=h7b6447c_0
yaml=0.2.5=h7b6447c_0
zlib=1.2.11=h7f8727e_4
zstd=1.4.9=haebb681_0

IS-Net/train_valid_inference_main.py Normal file

@@ -0,0 +1,713 @@
import os
import time
import numpy as np
from skimage import io
import torch, gc
import torch.nn as nn
from torch.autograd import Variable # Variable is a no-op wrapper on modern PyTorch; kept for compatibility
import torch.optim as optim
import torch.nn.functional as F
from data_loader_cache import get_im_gt_name_dict, create_dataloaders, GOSRandomHFlip, GOSResize, GOSRandomCrop, GOSNormalize #GOSDatasetCache,
from basics import f1_mae_torch #normPRED, GOSPRF1ScoresCache,f1score_torch,
from models import *
def get_gt_encoder(train_dataloaders, train_datasets, valid_dataloaders, valid_datasets, hypar, train_dataloaders_val, train_datasets_val): #model_path, model_save_fre, max_ite=1000000):
# train_dataloaders, train_datasets = create_dataloaders(train_nm_im_gt_list,
# cache_size = hypar["cache_size"],
# cache_boost = hypar["cache_boost_train"],
# my_transforms = [
# GOSRandomHFlip(),
# # GOSResize(hypar["input_size"]),
# # GOSRandomCrop(hypar["crop_size"]),
# GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
# ],
# batch_size = hypar["batch_size_train"],
# shuffle = True)
torch.manual_seed(hypar["seed"])
if torch.cuda.is_available():
torch.cuda.manual_seed(hypar["seed"])
print("define gt encoder ...")
net = ISNetGTEncoder() #UNETGTENCODERCombine()
## load the existing model gt encoder
if(hypar["gt_encoder_model"]!=""):
model_path = hypar["model_path"]+"/"+hypar["gt_encoder_model"]
if torch.cuda.is_available():
net.load_state_dict(torch.load(model_path))
net.cuda()
else:
net.load_state_dict(torch.load(model_path,map_location="cpu"))
print("gt encoder restored from the saved weights ...")
return net ############
if torch.cuda.is_available():
net.cuda()
print("--- define optimizer for GT Encoder---")
optimizer = optim.Adam(net.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
model_path = hypar["model_path"]
model_save_fre = hypar["model_save_fre"]
max_ite = hypar["max_ite"]
batch_size_train = hypar["batch_size_train"]
batch_size_valid = hypar["batch_size_valid"]
if(not os.path.exists(model_path)):
os.mkdir(model_path)
ite_num = hypar["start_ite"] # count the total iteration number
ite_num4val = 0 #
    running_loss = 0.0 # count the total loss
running_tar_loss = 0.0 # count the target output loss
last_f1 = [0 for x in range(len(valid_dataloaders))]
train_num = train_datasets[0].__len__()
net.train()
start_last = time.time()
gos_dataloader = train_dataloaders[0]
epoch_num = hypar["max_epoch_num"]
notgood_cnt = 0
for epoch in range(epoch_num): ## set the epoch num as 100000
for i, data in enumerate(gos_dataloader):
if(ite_num >= max_ite):
print("Training Reached the Maximal Iteration Number ", max_ite)
exit()
# start_read = time.time()
ite_num = ite_num + 1
ite_num4val = ite_num4val + 1
# get the inputs
labels = data['label']
if(hypar["model_digit"]=="full"):
labels = labels.type(torch.FloatTensor)
else:
labels = labels.type(torch.HalfTensor)
# wrap them in Variable
if torch.cuda.is_available():
labels_v = Variable(labels.cuda(), requires_grad=False)
else:
labels_v = Variable(labels, requires_grad=False)
# print("time lapse for data preparation: ", time.time()-start_read, ' s')
# y zero the parameter gradients
start_inf_loss_back = time.time()
optimizer.zero_grad()
ds, fs = net(labels_v)#net(inputs_v)
loss2, loss = net.compute_loss(ds, labels_v)
loss.backward()
optimizer.step()
running_loss += loss.item()
running_tar_loss += loss2.item()
# del outputs, loss
del ds, loss2, loss
end_inf_loss_back = time.time()-start_inf_loss_back
print("GT Encoder Training>>>"+model_path.split('/')[-1]+" - [epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f, time-per-iter: %3f s, time_read: %3f" % (
epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val, time.time()-start_last, time.time()-start_last-end_inf_loss_back))
start_last = time.time()
if ite_num % model_save_fre == 0: # validate every 2000 iterations
notgood_cnt += 1
# net.eval()
# tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time = valid_gt_encoder(net, valid_dataloaders, valid_datasets, hypar, epoch)
tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time = valid_gt_encoder(net, train_dataloaders_val, train_datasets_val, hypar, epoch)
net.train() # resume train
tmp_out = 0
print("last_f1:",last_f1)
print("tmp_f1:",tmp_f1)
for fi in range(len(last_f1)):
if(tmp_f1[fi]>last_f1[fi]):
tmp_out = 1
print("tmp_out:",tmp_out)
if(tmp_out):
notgood_cnt = 0
last_f1 = tmp_f1
tmp_f1_str = [str(round(f1x,4)) for f1x in tmp_f1]
tmp_mae_str = [str(round(mx,4)) for mx in tmp_mae]
maxf1 = '_'.join(tmp_f1_str)
meanM = '_'.join(tmp_mae_str)
# .cpu().detach().numpy()
model_name = "/GTENCODER-gpu_itr_"+str(ite_num)+\
"_traLoss_"+str(np.round(running_loss / ite_num4val,4))+\
"_traTarLoss_"+str(np.round(running_tar_loss / ite_num4val,4))+\
"_valLoss_"+str(np.round(val_loss /(i_val+1),4))+\
"_valTarLoss_"+str(np.round(tar_loss /(i_val+1),4)) + \
"_maxF1_" + maxf1 + \
"_mae_" + meanM + \
"_time_" + str(np.round(np.mean(np.array(tmp_time))/batch_size_valid,6))+".pth"
torch.save(net.state_dict(), model_path + model_name)
running_loss = 0.0
running_tar_loss = 0.0
ite_num4val = 0
if(tmp_f1[0]>0.99):
print("GT encoder is well-trained and obtained...")
return net
if(notgood_cnt >= hypar["early_stop"]):
                    print("No improvements in the last "+str(notgood_cnt)+" validation periods, so training is stopped!")
exit()
print("Training Reaches The Maximum Epoch Number")
return net
def valid_gt_encoder(net, valid_dataloaders, valid_datasets, hypar, epoch=0):
net.eval()
print("Validating...")
epoch_num = hypar["max_epoch_num"]
val_loss = 0.0
tar_loss = 0.0
tmp_f1 = []
tmp_mae = []
tmp_time = []
start_valid = time.time()
for k in range(len(valid_dataloaders)):
valid_dataloader = valid_dataloaders[k]
valid_dataset = valid_datasets[k]
val_num = valid_dataset.__len__()
mybins = np.arange(0,256)
PRE = np.zeros((val_num,len(mybins)-1))
REC = np.zeros((val_num,len(mybins)-1))
F1 = np.zeros((val_num,len(mybins)-1))
MAE = np.zeros((val_num))
val_cnt = 0.0
for i_val, data_val in enumerate(valid_dataloader):
# imidx_val, inputs_val, labels_val, shapes_val = data_val['imidx'], data_val['image'], data_val['label'], data_val['shape']
imidx_val, labels_val, shapes_val = data_val['imidx'], data_val['label'], data_val['shape']
if(hypar["model_digit"]=="full"):
labels_val = labels_val.type(torch.FloatTensor)
else:
labels_val = labels_val.type(torch.HalfTensor)
# wrap them in Variable
if torch.cuda.is_available():
labels_val_v = Variable(labels_val.cuda(), requires_grad=False)
else:
labels_val_v = Variable(labels_val,requires_grad=False)
t_start = time.time()
ds_val = net(labels_val_v)[0]
t_end = time.time()-t_start
tmp_time.append(t_end)
# loss2_val, loss_val = muti_loss_fusion(ds_val, labels_val_v)
loss2_val, loss_val = net.compute_loss(ds_val, labels_val_v)
# compute F measure
for t in range(hypar["batch_size_valid"]):
val_cnt = val_cnt + 1.0
print("num of val: ", val_cnt)
i_test = imidx_val[t].data.numpy()
pred_val = ds_val[0][t,:,:,:] # B x 1 x H x W
                ## recover the prediction spatial size to the original image size
                pred_val = torch.squeeze(F.interpolate(torch.unsqueeze(pred_val,0),(shapes_val[t][0],shapes_val[t][1]),mode='bilinear'))
ma = torch.max(pred_val)
mi = torch.min(pred_val)
pred_val = (pred_val-mi)/(ma-mi) # max = 1
# pred_val = normPRED(pred_val)
gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255
                with torch.no_grad():
                    gt = torch.tensor(gt, device="cuda" if torch.cuda.is_available() else "cpu")
pre,rec,f1,mae = f1_mae_torch(pred_val*255, gt, valid_dataset, i_test, mybins, hypar)
PRE[i_test,:]=pre
REC[i_test,:] = rec
F1[i_test,:] = f1
MAE[i_test] = mae
del ds_val, gt
gc.collect()
torch.cuda.empty_cache()
# if(loss_val.data[0]>1):
val_loss += loss_val.item()#data[0]
tar_loss += loss2_val.item()#data[0]
print("[validating: %5d/%5d] val_ls:%f, tar_ls: %f, f1: %f, mae: %f, time: %f"% (i_val, val_num, val_loss / (i_val + 1), tar_loss / (i_val + 1), np.amax(F1[i_test,:]), MAE[i_test],t_end))
del loss2_val, loss_val
print('============================')
PRE_m = np.mean(PRE,0)
REC_m = np.mean(REC,0)
f1_m = (1+0.3)*PRE_m*REC_m/(0.3*PRE_m+REC_m+1e-8)
# print('--------------:', np.mean(f1_m))
tmp_f1.append(np.amax(f1_m))
tmp_mae.append(np.mean(MAE))
print("The max F1 Score: %f"%(np.max(f1_m)))
print("MAE: ", np.mean(MAE))
# print('[epoch: %3d/%3d, ite: %5d] tra_ls: %3f, val_ls: %3f, tar_ls: %3f, maxf1: %3f, val_time: %6f'% (epoch + 1, epoch_num, ite_num, running_loss / ite_num4val, val_loss/val_cnt, tar_loss/val_cnt, tmp_f1[-1], time.time()-start_valid))
return tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time
def train(net, optimizer, train_dataloaders, train_datasets, valid_dataloaders, valid_datasets, hypar,train_dataloaders_val, train_datasets_val): #model_path, model_save_fre, max_ite=1000000):
if hypar["interm_sup"]:
print("Get the gt encoder ...")
featurenet = get_gt_encoder(train_dataloaders, train_datasets, valid_dataloaders, valid_datasets, hypar,train_dataloaders_val, train_datasets_val)
## freeze the weights of gt encoder
for param in featurenet.parameters():
param.requires_grad=False
model_path = hypar["model_path"]
model_save_fre = hypar["model_save_fre"]
max_ite = hypar["max_ite"]
batch_size_train = hypar["batch_size_train"]
batch_size_valid = hypar["batch_size_valid"]
if(not os.path.exists(model_path)):
os.mkdir(model_path)
    ite_num = hypar["start_ite"] # count the total iteration number
    ite_num4val = 0 #
    running_loss = 0.0 # count the total loss
running_tar_loss = 0.0 # count the target output loss
last_f1 = [0 for x in range(len(valid_dataloaders))]
train_num = train_datasets[0].__len__()
net.train()
start_last = time.time()
gos_dataloader = train_dataloaders[0]
epoch_num = hypar["max_epoch_num"]
notgood_cnt = 0
for epoch in range(epoch_num): ## set the epoch num as 100000
for i, data in enumerate(gos_dataloader):
if(ite_num >= max_ite):
print("Training Reached the Maximal Iteration Number ", max_ite)
exit()
# start_read = time.time()
ite_num = ite_num + 1
ite_num4val = ite_num4val + 1
# get the inputs
inputs, labels = data['image'], data['label']
if(hypar["model_digit"]=="full"):
inputs = inputs.type(torch.FloatTensor)
labels = labels.type(torch.FloatTensor)
else:
inputs = inputs.type(torch.HalfTensor)
labels = labels.type(torch.HalfTensor)
# wrap them in Variable
if torch.cuda.is_available():
inputs_v, labels_v = Variable(inputs.cuda(), requires_grad=False), Variable(labels.cuda(), requires_grad=False)
else:
inputs_v, labels_v = Variable(inputs, requires_grad=False), Variable(labels, requires_grad=False)
# print("time lapse for data preparation: ", time.time()-start_read, ' s')
# y zero the parameter gradients
start_inf_loss_back = time.time()
optimizer.zero_grad()
if hypar["interm_sup"]:
# forward + backward + optimize
ds,dfs = net(inputs_v)
_,fs = featurenet(labels_v) ## extract the gt encodings
loss2, loss = net.compute_loss_kl(ds, labels_v, dfs, fs, mode='MSE')
else:
# forward + backward + optimize
ds,_ = net(inputs_v)
loss2, loss = muti_loss_fusion(ds, labels_v)
loss.backward()
optimizer.step()
# # print statistics
running_loss += loss.item()
running_tar_loss += loss2.item()
# del outputs, loss
del ds, loss2, loss
end_inf_loss_back = time.time()-start_inf_loss_back
print(">>>"+model_path.split('/')[-1]+" - [epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f, time-per-iter: %3f s, time_read: %3f" % (
epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val, time.time()-start_last, time.time()-start_last-end_inf_loss_back))
start_last = time.time()
if ite_num % model_save_fre == 0: # validate every 2000 iterations
notgood_cnt += 1
net.eval()
tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time = valid(net, valid_dataloaders, valid_datasets, hypar, epoch)
net.train() # resume train
tmp_out = 0
print("last_f1:",last_f1)
print("tmp_f1:",tmp_f1)
for fi in range(len(last_f1)):
if(tmp_f1[fi]>last_f1[fi]):
tmp_out = 1
print("tmp_out:",tmp_out)
if(tmp_out):
notgood_cnt = 0
last_f1 = tmp_f1
tmp_f1_str = [str(round(f1x,4)) for f1x in tmp_f1]
tmp_mae_str = [str(round(mx,4)) for mx in tmp_mae]
maxf1 = '_'.join(tmp_f1_str)
meanM = '_'.join(tmp_mae_str)
# .cpu().detach().numpy()
model_name = "/gpu_itr_"+str(ite_num)+\
"_traLoss_"+str(np.round(running_loss / ite_num4val,4))+\
"_traTarLoss_"+str(np.round(running_tar_loss / ite_num4val,4))+\
"_valLoss_"+str(np.round(val_loss /(i_val+1),4))+\
"_valTarLoss_"+str(np.round(tar_loss /(i_val+1),4)) + \
"_maxF1_" + maxf1 + \
"_mae_" + meanM + \
"_time_" + str(np.round(np.mean(np.array(tmp_time))/batch_size_valid,6))+".pth"
torch.save(net.state_dict(), model_path + model_name)
running_loss = 0.0
running_tar_loss = 0.0
ite_num4val = 0
if(notgood_cnt >= hypar["early_stop"]):
                    print("No improvements in the last "+str(notgood_cnt)+" validation periods, so training is stopped!")
exit()
print("Training Reaches The Maximum Epoch Number")
def valid(net, valid_dataloaders, valid_datasets, hypar, epoch=0):
net.eval()
print("Validating...")
epoch_num = hypar["max_epoch_num"]
val_loss = 0.0
tar_loss = 0.0
val_cnt = 0.0
tmp_f1 = []
tmp_mae = []
tmp_time = []
start_valid = time.time()
for k in range(len(valid_dataloaders)):
valid_dataloader = valid_dataloaders[k]
valid_dataset = valid_datasets[k]
val_num = valid_dataset.__len__()
mybins = np.arange(0,256)
PRE = np.zeros((val_num,len(mybins)-1))
REC = np.zeros((val_num,len(mybins)-1))
F1 = np.zeros((val_num,len(mybins)-1))
MAE = np.zeros((val_num))
for i_val, data_val in enumerate(valid_dataloader):
val_cnt = val_cnt + 1.0
imidx_val, inputs_val, labels_val, shapes_val = data_val['imidx'], data_val['image'], data_val['label'], data_val['shape']
if(hypar["model_digit"]=="full"):
inputs_val = inputs_val.type(torch.FloatTensor)
labels_val = labels_val.type(torch.FloatTensor)
else:
inputs_val = inputs_val.type(torch.HalfTensor)
labels_val = labels_val.type(torch.HalfTensor)
# wrap them in Variable
if torch.cuda.is_available():
inputs_val_v, labels_val_v = Variable(inputs_val.cuda(), requires_grad=False), Variable(labels_val.cuda(), requires_grad=False)
else:
inputs_val_v, labels_val_v = Variable(inputs_val, requires_grad=False), Variable(labels_val,requires_grad=False)
t_start = time.time()
ds_val = net(inputs_val_v)[0]
t_end = time.time()-t_start
tmp_time.append(t_end)
# loss2_val, loss_val = muti_loss_fusion(ds_val, labels_val_v)
loss2_val, loss_val = net.compute_loss(ds_val, labels_val_v)
# compute F measure
for t in range(hypar["batch_size_valid"]):
i_test = imidx_val[t].data.numpy()
pred_val = ds_val[0][t,:,:,:] # B x 1 x H x W
                ## recover the prediction spatial size to the original image size
                pred_val = torch.squeeze(F.interpolate(torch.unsqueeze(pred_val,0),(shapes_val[t][0],shapes_val[t][1]),mode='bilinear'))
# pred_val = normPRED(pred_val)
ma = torch.max(pred_val)
mi = torch.min(pred_val)
pred_val = (pred_val-mi)/(ma-mi) # max = 1
gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255
                with torch.no_grad():
                    gt = torch.tensor(gt, device="cuda" if torch.cuda.is_available() else "cpu")
pre,rec,f1,mae = f1_mae_torch(pred_val*255, gt, valid_dataset, i_test, mybins, hypar)
PRE[i_test,:]=pre
REC[i_test,:] = rec
F1[i_test,:] = f1
MAE[i_test] = mae
del ds_val, gt
gc.collect()
torch.cuda.empty_cache()
# if(loss_val.data[0]>1):
val_loss += loss_val.item()#data[0]
tar_loss += loss2_val.item()#data[0]
print("[validating: %5d/%5d] val_ls:%f, tar_ls: %f, f1: %f, mae: %f, time: %f"% (i_val, val_num, val_loss / (i_val + 1), tar_loss / (i_val + 1), np.amax(F1[i_test,:]), MAE[i_test],t_end))
del loss2_val, loss_val
print('============================')
PRE_m = np.mean(PRE,0)
REC_m = np.mean(REC,0)
f1_m = (1+0.3)*PRE_m*REC_m/(0.3*PRE_m+REC_m+1e-8)
tmp_f1.append(np.amax(f1_m))
tmp_mae.append(np.mean(MAE))
return tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time
def main(train_datasets,
         valid_datasets,
         hypar):  # hypar["mode"]: "train" or "valid"

    ### --- Step 1: Build datasets and dataloaders ---
    dataloaders_train = []
    dataloaders_valid = []

    if(hypar["mode"] == "train"):
        print("--- create training dataloader ---")
        ## collect the training dataset names and paths
        train_nm_im_gt_list = get_im_gt_name_dict(train_datasets, flag="train")
        ## build dataloaders for the training datasets
        train_dataloaders, train_datasets = create_dataloaders(train_nm_im_gt_list,
                                                               cache_size = hypar["cache_size"],
                                                               cache_boost = hypar["cache_boost_train"],
                                                               my_transforms = [
                                                                   GOSRandomHFlip(),  ## horizontal flip augmentation; comment this line out to disable it
                                                                   # GOSResize(hypar["input_size"]),
                                                                   # GOSRandomCrop(hypar["crop_size"]),  ## uncomment this line to enable random-crop augmentation
                                                                   GOSNormalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0]),
                                                               ],
                                                               batch_size = hypar["batch_size_train"],
                                                               shuffle = True)
        ## a second, non-shuffled loader over the training set, used for evaluation
        train_dataloaders_val, train_datasets_val = create_dataloaders(train_nm_im_gt_list,
                                                                       cache_size = hypar["cache_size"],
                                                                       cache_boost = hypar["cache_boost_train"],
                                                                       my_transforms = [
                                                                           GOSNormalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0]),
                                                                       ],
                                                                       batch_size = hypar["batch_size_valid"],
                                                                       shuffle = False)
        print(len(train_dataloaders), " train dataloaders created")

    print("--- create valid dataloader ---")
    ## collect the validation (or test) dataset names and paths
    valid_nm_im_gt_list = get_im_gt_name_dict(valid_datasets, flag="valid")
    ## build dataloaders for the validation or test datasets
    valid_dataloaders, valid_datasets = create_dataloaders(valid_nm_im_gt_list,
                                                           cache_size = hypar["cache_size"],
                                                           cache_boost = hypar["cache_boost_valid"],
                                                           my_transforms = [
                                                               GOSNormalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0]),
                                                               # GOSResize(hypar["input_size"])
                                                           ],
                                                           batch_size = hypar["batch_size_valid"],
                                                           shuffle = False)
    print(len(valid_dataloaders), " valid dataloaders created")
    # print(valid_datasets[0]["data_name"])

    ### --- Step 2: Build Model and Optimizer ---
    print("--- build model ---")
    net = hypar["model"]  # e.g., ISNetDIS()

    # convert the model to half precision if requested
    if(hypar["model_digit"] == "half"):
        net.half()
        for layer in net.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.float()  # keep BatchNorm layers in full precision for numerical stability

    if torch.cuda.is_available():
        net.cuda()

    if(hypar["restore_model"] != ""):
        print("restore model from:")
        print(hypar["model_path"] + "/" + hypar["restore_model"])
        if torch.cuda.is_available():
            net.load_state_dict(torch.load(hypar["model_path"] + "/" + hypar["restore_model"]))
        else:
            net.load_state_dict(torch.load(hypar["model_path"] + "/" + hypar["restore_model"], map_location="cpu"))

    print("--- define optimizer ---")
    optimizer = optim.Adam(net.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

    ### --- Step 3: Train or Valid Model ---
    if(hypar["mode"] == "train"):
        train(net,
              optimizer,
              train_dataloaders,
              train_datasets,
              valid_dataloaders,
              valid_datasets,
              hypar,
              train_dataloaders_val, train_datasets_val)
    else:
        valid(net,
              valid_dataloaders,
              valid_datasets,
              hypar)
if __name__ == "__main__":
### --------------- STEP 1: Configuring the Train, Valid and Test datasets ---------------
## configure the train, valid and inference datasets
train_datasets, valid_datasets = [], []
dataset_1, dataset_1 = {}, {}
dataset_tr = {"name": "DIS5K-TR",
"im_dir": "../DIS5K/DIS-TR/im",
"gt_dir": "../DIS5K/DIS-TR/gt",
"im_ext": ".jpg",
"gt_ext": ".png",
"cache_dir":"../DIS5K-Cache/DIS-TR"}
dataset_vd = {"name": "DIS5K-VD",
"im_dir": "../DIS5K/DIS-VD/im",
"gt_dir": "../DIS5K/DIS-VD/gt",
"im_ext": ".jpg",
"gt_ext": ".png",
"cache_dir":"../DIS5K-Cache/DIS-VD"}
dataset_te1 = {"name": "DIS5K-TE1",
"im_dir": "../DIS5K/DIS-TE1/im",
"gt_dir": "../DIS5K/DIS-TE1/gt",
"im_ext": ".jpg",
"gt_ext": ".png",
"cache_dir":"../DIS5K-Cache/DIS-TE1"}
dataset_te2 = {"name": "DIS5K-TE2",
"im_dir": "../DIS5K/DIS-TE2/im",
"gt_dir": "../DIS5K/DIS-TE2/gt",
"im_ext": ".jpg",
"gt_ext": ".png",
"cache_dir":"../DIS5K-Cache/DIS-TE2"}
dataset_te3 = {"name": "DIS5K-TE3",
"im_dir": "../DIS5K/DIS-TE3/im",
"gt_dir": "../DIS5K/DIS-TE3/gt",
"im_ext": ".jpg",
"gt_ext": ".png",
"cache_dir":"../DIS5K-Cache/DIS-TE3"}
dataset_te4 = {"name": "DIS5K-TE4",
"im_dir": "../DIS5K/DIS-TE4/im",
"gt_dir": "../DIS5K/DIS-TE4/gt",
"im_ext": ".jpg",
"gt_ext": ".png",
"cache_dir":"../DIS5K-Cache/DIS-TE4"}
train_datasets = [dataset_tr] ## users can create mutiple dictionary for setting a list of datasets as training set
# valid_datasets = [dataset_vd] ## users can create mutiple dictionary for setting a list of datasets as vaidation sets or inference sets
valid_datasets = [dataset_vd] #, dataset_te1, dataset_te2, dataset_te3, dataset_te4] # and hypar["mode"] = "valid" for inference,
### --------------- STEP 2: Configuring the hyperparamters for Training, validation and inferencing ---------------
hypar = {}
## -- 2.1. configure the model saving or restoring path --
hypar["mode"] = "train"
## "train": for training,
## "valid": for validation and inferening,
## in "valid" mode, it will calculate the accuracy as well as save the prediciton results into the "hypar["valid_out_dir"]", which shouldn't be ""
## otherwise only accuracy will be calculated and no predictions will be saved
hypar["interm_sup"] = False ## indicate if activate intermediate feature supervision
if hypar["mode"] == "train":
hypar["valid_out_dir"] = "" ## for "train" model leave it as "", for "valid"("inference") mode: set it according to your local directory
hypar["model_path"] ="../saved_models/IS-Net-test" ## model weights saving (or restoring) path
hypar["restore_model"] = "" ## name of the segmentation model weights .pth for resume training process from last stop or for the inferencing
hypar["start_ite"] = 0 ## start iteration for the training, can be changed to match the restored training process
hypar["gt_encoder_model"] = ""
else: ## configure the segmentation output path and the to-be-used model weights path
hypar["valid_out_dir"] = "../DIS5K-Results-test" ## output inferenced segmentation maps into this fold
hypar["model_path"] ="../saved_models/IS-Net" ## load trained weights from this path
hypar["restore_model"] = "isnet.pth" ## name of the to-be-loaded weights
# if hypar["restore_model"]!="":
# hypar["start_ite"] = int(hypar["restore_model"].split("_")[2])
## -- 2.2. choose floating point accuracy --
hypar["model_digit"] = "full" ## indicates "half" or "full" accuracy of float number
hypar["seed"] = 0
## -- 2.3. cache data spatial size --
## To handle large size input images, which take a lot of time for loading in training,
# we introduce the cache mechanism for pre-convering and resizing the jpg and png images into .pt file
hypar["cache_size"] = [1024, 1024] ## cached input spatial resolution, can be configured into different size
hypar["cache_boost_train"] = False ## "True" or "False", indicates wheather to load all the training datasets into RAM, True will greatly speed the training process while requires more RAM
hypar["cache_boost_valid"] = False ## "True" or "False", indicates wheather to load all the validation datasets into RAM, True will greatly speed the training process while requires more RAM
## --- 2.4. data augmentation parameters ---
hypar["input_size"] = [1024, 1024] ## mdoel input spatial size, usually use the same value hypar["cache_size"], which means we don't further resize the images
hypar["crop_size"] = [1024, 1024] ## random crop size from the input, it is usually set as smaller than hypar["cache_size"], e.g., [920,920] for data augmentation
hypar["random_flip_h"] = 1 ## horizontal flip, currently hard coded in the dataloader and it is not in use
hypar["random_flip_v"] = 0 ## vertical flip , currently not in use
## --- 2.5. define model ---
print("building model...")
hypar["model"] = ISNetDIS() #U2NETFASTFEATURESUP()
hypar["early_stop"] = 20 ## stop the training when no improvement in the past 20 validation periods, smaller numbers can be used here e.g., 5 or 10.
hypar["model_save_fre"] = 2000 ## valid and save model weights every 2000 iterations
hypar["batch_size_train"] = 8 ## batch size for training
hypar["batch_size_valid"] = 1 ## batch size for validation and inferencing
print("batch size: ", hypar["batch_size_train"])
hypar["max_ite"] = 10000000 ## if early stop couldn't stop the training process, stop it by the max_ite_num
hypar["max_epoch_num"] = 1000000 ## if early stop and max_ite couldn't stop the training process, stop it by the max_epoch_num
main(train_datasets,
valid_datasets,
hypar=hypar)

161
README.md
View File

@ -2,27 +2,147 @@
<img width="420" height="320" src="figures/dis-logo-official.png">
</p>
![dis5k-v1-sailship](figures/dis5k-v1-sailship.jpeg)
<br>
## [Highly Accurate Dichotomous Image Segmentation ECCV 2022](https://arxiv.org/pdf/2203.03041.pdf)
[Xuebin Qin](https://xuebinqin.github.io/), [Hang Dai](https://scholar.google.co.uk/citations?user=6yvjpQQAAAAJ&hl=en), [Xiaobin Hu](https://scholar.google.de/citations?user=3lMuodUAAAAJ&hl=en), [Deng-Ping Fan*](https://dengpingfan.github.io/), [Ling Shao](https://scholar.google.com/citations?user=z84rLjoAAAAJ&hl=en), [Luc Van Gool](https://scholar.google.com/citations?user=TwMib_QAAAAJ&hl=en).
<br>
## This is the official repo for our newly formulated DIS task:
[**Project Page**](https://xuebinqin.github.io/dis/index.html), [**Arxiv**](https://arxiv.org/pdf/2203.03041.pdf).
<br>
## Updates !!!
<br>
** (2022-Jul.-17)** Our paper, code and dataset are now officially released!!! Please check our project page for more details: [**Project Page**](https://xuebinqin.github.io/dis/index.html).<br>
** (2022-Jul.-5)** Our DIS work has been accepted by ECCV 2022; the code and dataset will be released before July 17th, 2022. Please be aware of our updates.
## 1. [Our DIS5K Dataset V1.0 (Version Alias: DIS5K Sailship)](https://xuebinqin.github.io/dis/index.html)
<br>
### Download [Google Drive](https://drive.google.com/file/d/1jOC2zK0GowBvEt03B7dGugCRDVAoeIqq/view?usp=sharing) or [Baidu Pan 提取码rtgw](https://pan.baidu.com/s/1y6CQJYledfYyEO0C_Gejpw?pwd=rtgw)
![dis5k-dataset-v1-sailship](figures/DIS5k-dataset-v1-sailship.png)
![complexities-qual](figures/complexities-qual.jpeg)
![categories](figures/categories.jpeg)
<br>
## 2. APPLICATIONS of Our DIS5K Dataset
<br>
### 3D Modeling
![3d-modeling](figures/3d-modeling.png)
### Image Editing
![ship-demo](figures/ship-demo.gif)
### Art Design Materials
![bg-removal](figures/bg-removal.gif)
### Still Image Animation
![view-move](figures/view-move.gif)
### AR
![motor-demo](figures/motor-demo.gif)
### 3D Rendering
![video-3d](figures/video-3d.gif)
<br>
## 3. Architecture of Our IS-Net
<br>
![is-net](figures/is-net.png)
<br>
## 4. Human Correction Efforts (HCE)
<br>
![hce-metric](figures/hce-metric.png)
<br>
## 5. Experimental Results
<br>
### Predicted Maps of Our IS-Net and Other SOTAs: [(Google Drive)](https://drive.google.com/file/d/1FMtDLFrL6xVc41eKlLnuZWMBAErnKv0Y/view?usp=sharing), [(Baidu Pan 提取码ph1d)](https://pan.baidu.com/s/1WUk2RYYpii2xzrvLna9Fsg?pwd=ph1d)
### Qualitative Comparisons Against SOTAs
![qual-comp](figures/qual-comp.jpg)
### Quantitative Comparisons Against SOTAs
![quan-comp](figures/quan-comp.png)
<br>
## 6. Run Our Code
<br>
### (1) Clone this repo
```
git clone https://github.com/xuebinqin/DIS.git
```
### (2) Configuring the environment: go to the root ```DIS``` folder and run
```
conda env create -f pytorch18.yml
```
Or you can check ```requirements.txt``` to configure the dependencies manually (e.g., ```pip install -r requirements.txt```).
### (3) Train:
#### (a) Open ```train_valid_inference_main.py``` and set the ```train_datasets``` and ```valid_datasets``` to be used, e.g., ```valid_datasets=[dataset_vd]```
#### (b) Set ```hypar["mode"]``` to ```"train"```
#### (c) Create a new folder ```your_model_weights``` in the ```saved_models``` directory, set ```hypar["model_path"] ="../saved_models/your_model_weights"```, and make sure ```hypar["valid_out_dir"]``` (line 668) is set to ```""```; otherwise the prediction maps of the validation stage will be saved to that directory, which will slow training down (see the configuration sketch after step (d))
#### (d) Run
```
python train_valid_inference_main.py
```
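Taken together, steps (a)-(c) amount to the following edits near the bottom of ```train_valid_inference_main.py```. This is a minimal sketch assuming the default DIS5K layout shown in the script; ```your_model_weights``` is a placeholder folder name that you create yourself:
```
train_datasets = [dataset_tr]    # DIS5K-TR, defined earlier in the script
valid_datasets = [dataset_vd]    # DIS5K-VD is used for validation during training

hypar["mode"] = "train"
hypar["model_path"] = "../saved_models/your_model_weights"  # placeholder: the folder created in step (c)
hypar["valid_out_dir"] = ""      # keep empty during training so validation maps are not written to disk
```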
### (4) Inference
#### (a) Download the pre-trained weights ```isnet.pth``` (for fair academic comparisons only; an optimized model for engineering or common use will be released soon) from [(Google Drive)](https://drive.google.com/file/d/1KyMpRjewZdyYfxHPYcd-ZbanIXtin0Sn/view?usp=sharing) or [(Baidu Pan 提取码xbfk)](https://pan.baidu.com/s/1-X2WutiBkWPt-oakuvZ10w?pwd=xbfk), and store ```isnet.pth``` in ```saved_models/IS-Net```
#### (b) Open ```train_valid_inference_main.py``` and set the ```valid_datasets``` to be inferenced, e.g., ```valid_datasets=[dataset_te1, dataset_te2, dataset_te3, dataset_te4]``` (see the configuration sketch after step (e))
#### (c) Set the ```hypar["mode"]``` to ```"valid"```
#### (d) Set the output directory of your predicted maps, e.g., ```hypar["valid_out_dir"] = "../DIS5K-Results-test"```
#### (e) Run
```
python train_valid_inference_main.py
```
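For reference, the inference setup of steps (b)-(d) corresponds to the following values in ```train_valid_inference_main.py``` (a sketch; the paths follow the defaults shown in the script):
```
valid_datasets = [dataset_te1, dataset_te2, dataset_te3, dataset_te4]  # the four DIS5K test sets

hypar["mode"] = "valid"                           # run inference instead of training
hypar["model_path"] = "../saved_models/IS-Net"    # folder holding the downloaded isnet.pth
hypar["restore_model"] = "isnet.pth"              # the pre-trained weights from step (a)
hypar["valid_out_dir"] = "../DIS5K-Results-test"  # the predicted maps are saved here
```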
### (5) Use of our Human Correction Efforts (HCE) metric: set the ground truth directory ```gt_root``` and the prediction directory ```pred_root```. To reduce the time cost of computing HCE, the skeletons of the DIS5K ground truth can be pre-computed and stored in ```gt_ske_root```. If ```gt_ske_root=""```, the HCE code will compute the skeletons online, which usually takes a long time for large-size ground truth. Then run ```python hce_metric_main.py``` (a configuration sketch follows). Other metrics are evaluated based on the [SOCToolbox](https://github.com/mczhuge/SOCToolbox).
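A minimal sketch of the HCE configuration; the variable names follow the description above, and the example paths are placeholders for your local layout:
```
# at the top of hce_metric_main.py
gt_root = "../DIS5K/DIS-VD/gt"                # ground-truth masks (example path)
pred_root = "../DIS5K-Results-test/DIS5K-VD"  # your predicted maps (example path)
gt_ske_root = ""                              # "" means skeletons are computed online (slow for large masks)
```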
<br>
## 7. Term of Use
Our code and evaluation metrics are released under the Apache License 2.0. The terms of use for our DIS5K dataset are provided in [DIS5K-Dataset-Terms-of-Use.pdf](DIS5K-Dataset-Terms-of-Use.pdf).
<br>
## Acknowledgements
<br>
We would like to thank Dr. [Ibrahim Almakky](https://scholar.google.co.uk/citations?user=T9MTcK0AAAAJ&hl=en) for his help in implementing the dataloader cache mechanism for loading large-size training samples, and Jiayi Zhu for his efforts in re-organizing our code and dataset.
<br>
## Citation
<br>
```
@InProceedings{qin2022,
author={Xuebin Qin and Hang Dai and Xiaobin Hu and Deng-Ping Fan and Ling Shao and Luc Van Gool},
@ -31,8 +151,15 @@ This is the official repo for our new project DIS:
year={2022}
}
```
## Our Previous Works: [U<sup>2</sup>-Net](https://github.com/xuebinqin/U-2-Net), [BASNet](https://github.com/xuebinqin/BASNet).
<br>
```
@InProceedings{Qin_2020_PR,
title = {U2-Net: Going Deeper with Nested U-Structure for Salient Object Detection},
author = {Qin, Xuebin and Zhang, Zichen and Huang, Chenyang and Dehghan, Masood and Zaiane, Osmar and Jagersand, Martin},
@ -42,13 +169,6 @@ This is the official repo for our new project DIS:
year = {2020}
}
@InProceedings{Qin_2019_CVPR,
author = {Qin, Xuebin and Zhang, Zichen and Huang, Chenyang and Gao, Chao and Dehghan, Masood and Jagersand, Martin},
title = {BASNet: Boundary-Aware Salient Object Detection},
@ -56,3 +176,10 @@ This is the official repo for our new project DIS:
month = {June},
year = {2019}
}
@article{qin2021boundary,
title={Boundary-aware segmentation network for mobile and web applications},
author={Qin, Xuebin and Fan, Deng-Ping and Huang, Chenyang and Diagne, Cyril and Zhang, Zichen and Sant'Anna, Adri{\`a} Cabeza and Suarez, Albert and Jagersand, Martin and Shao, Ling},
journal={arXiv preprint arXiv:2101.04704},
year={2021}
}
```

BIN
figures/.DS_Store vendored Normal file

Binary file not shown.

BIN
figures/3d-modeling.png Normal file

Binary file not shown.

Binary file not shown.

BIN
figures/categories.jpeg Normal file

Binary file not shown.

Binary file not shown.

Binary file not shown.

BIN
figures/hce-metric.png Normal file

Binary file not shown.

BIN
figures/is-net.png Normal file

Binary file not shown.

BIN
figures/quan-comp.png Normal file

Binary file not shown.

BIN
figures/statistics.jpeg Normal file

Binary file not shown.

BIN
figures/video-3d.gif Normal file

Binary file not shown.

BIN
saved_models/.DS_Store vendored Normal file

Binary file not shown.