import os
import time
import numpy as np
from skimage import io

import torch, gc
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torch.nn.functional as F

from data_loader_cache import get_im_gt_name_dict, create_dataloaders, GOSRandomHFlip, GOSResize, GOSRandomCrop, GOSNormalize
from basics import f1_mae_torch
from models import *
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def get_gt_encoder(train_dataloaders, train_datasets, valid_dataloaders, valid_datasets, hypar, train_dataloaders_val, train_datasets_val):
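    """Train (or restore) the ground-truth (GT) encoder used for intermediate feature supervision.

    If hypar["gt_encoder_model"] names an existing checkpoint, it is loaded and returned
    directly. Otherwise a fresh ISNetGTEncoder is trained on the ground-truth masks alone;
    it is returned as soon as its max F1 on the un-augmented training set exceeds 0.99,
    and training is aborted when hypar["early_stop"] validation periods pass without
    improvement or hypar["max_ite"] iterations are reached.
    """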
    torch.manual_seed(hypar["seed"])
    if torch.cuda.is_available():
        torch.cuda.manual_seed(hypar["seed"])

    print("define gt encoder ...")
    net = ISNetGTEncoder()

    ## load the existing gt encoder weights if a checkpoint is given
    if hypar["gt_encoder_model"] != "":
        model_path = hypar["model_path"] + "/" + hypar["gt_encoder_model"]
        if torch.cuda.is_available():
            net.load_state_dict(torch.load(model_path))
            net.cuda()
        else:
            net.load_state_dict(torch.load(model_path, map_location="cpu"))
        print("gt encoder restored from the saved weights ...")
        return net

    if torch.cuda.is_available():
        net.cuda()

    print("--- define optimizer for GT Encoder ---")
    optimizer = optim.Adam(net.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

    model_path = hypar["model_path"]
    model_save_fre = hypar["model_save_fre"]
    max_ite = hypar["max_ite"]
    batch_size_train = hypar["batch_size_train"]
    batch_size_valid = hypar["batch_size_valid"]

    if not os.path.exists(model_path):
        os.mkdir(model_path)

    ite_num = hypar["start_ite"]  # count the total iteration number
    ite_num4val = 0
    running_loss = 0.0      # count the total loss
    running_tar_loss = 0.0  # count the target output loss
    last_f1 = [0 for x in range(len(valid_dataloaders))]

    train_num = train_datasets[0].__len__()

    net.train()

    start_last = time.time()
    gos_dataloader = train_dataloaders[0]
    epoch_num = hypar["max_epoch_num"]
    notgood_cnt = 0
    for epoch in range(epoch_num):

        for i, data in enumerate(gos_dataloader):

            if ite_num >= max_ite:
                print("Training Reached the Maximal Iteration Number ", max_ite)
                exit()

            ite_num = ite_num + 1
            ite_num4val = ite_num4val + 1

            # get the inputs
            labels = data['label']

            if hypar["model_digit"] == "full":
                labels = labels.type(torch.FloatTensor)
            else:
                labels = labels.type(torch.HalfTensor)

            # wrap them in Variable
            if torch.cuda.is_available():
                labels_v = Variable(labels.cuda(), requires_grad=False)
            else:
                labels_v = Variable(labels, requires_grad=False)

            # zero the parameter gradients
            start_inf_loss_back = time.time()
            optimizer.zero_grad()

            ds, fs = net(labels_v)
            loss2, loss = net.compute_loss(ds, labels_v)

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            running_tar_loss += loss2.item()

            del ds, loss2, loss
            end_inf_loss_back = time.time() - start_inf_loss_back

            print("GT Encoder Training>>>" + model_path.split('/')[-1] + " - [epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f, time-per-iter: %3f s, time_read: %3f" % (
                epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val, time.time() - start_last, time.time() - start_last - end_inf_loss_back))
            start_last = time.time()
            if ite_num % model_save_fre == 0:  # validate every model_save_fre iterations
                notgood_cnt += 1
                tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time = valid_gt_encoder(net, train_dataloaders_val, train_datasets_val, hypar, epoch)
                net.train()  # resume training mode

                tmp_out = 0
                print("last_f1:", last_f1)
                print("tmp_f1:", tmp_f1)
                for fi in range(len(last_f1)):
                    if tmp_f1[fi] > last_f1[fi]:
                        tmp_out = 1
                print("tmp_out:", tmp_out)

                if tmp_out:
                    notgood_cnt = 0
                    last_f1 = tmp_f1
                    tmp_f1_str = [str(round(f1x, 4)) for f1x in tmp_f1]
                    tmp_mae_str = [str(round(mx, 4)) for mx in tmp_mae]
                    maxf1 = '_'.join(tmp_f1_str)
                    meanM = '_'.join(tmp_mae_str)
                    model_name = "/GTENCODER-gpu_itr_" + str(ite_num) + \
                                 "_traLoss_" + str(np.round(running_loss / ite_num4val, 4)) + \
                                 "_traTarLoss_" + str(np.round(running_tar_loss / ite_num4val, 4)) + \
                                 "_valLoss_" + str(np.round(val_loss / (i_val + 1), 4)) + \
                                 "_valTarLoss_" + str(np.round(tar_loss / (i_val + 1), 4)) + \
                                 "_maxF1_" + maxf1 + \
                                 "_mae_" + meanM + \
                                 "_time_" + str(np.round(np.mean(np.array(tmp_time)) / batch_size_valid, 6)) + ".pth"
                    torch.save(net.state_dict(), model_path + model_name)

                running_loss = 0.0
                running_tar_loss = 0.0
                ite_num4val = 0

                if tmp_f1[0] > 0.99:
                    print("GT encoder is well-trained and obtained ...")
                    return net

                if notgood_cnt >= hypar["early_stop"]:
                    print("No improvements in the last " + str(notgood_cnt) + " validation periods, so training stopped !")
                    exit()

    print("Training Reaches The Maximum Epoch Number")
    return net
def valid_gt_encoder(net, valid_dataloaders, valid_datasets, hypar, epoch=0):
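    """Evaluate the GT encoder on the given dataloaders.

    The encoder is fed the ground-truth masks only (no images). Each prediction is resized
    back to the original resolution, min-max normalized, and scored against the original GT
    map with f1_mae_torch. Returns the per-dataset max-F1 and MAE lists, the accumulated
    validation and target losses, the last batch index, and the per-batch inference times.
    """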
    net.eval()
    print("Validating...")
    epoch_num = hypar["max_epoch_num"]

    val_loss = 0.0
    tar_loss = 0.0

    tmp_f1 = []
    tmp_mae = []
    tmp_time = []

    start_valid = time.time()

    for k in range(len(valid_dataloaders)):

        valid_dataloader = valid_dataloaders[k]
        valid_dataset = valid_datasets[k]

        val_num = valid_dataset.__len__()
        mybins = np.arange(0, 256)
        PRE = np.zeros((val_num, len(mybins) - 1))
        REC = np.zeros((val_num, len(mybins) - 1))
        F1 = np.zeros((val_num, len(mybins) - 1))
        MAE = np.zeros(val_num)

        val_cnt = 0.0
        i_val = None
        for i_val, data_val in enumerate(valid_dataloader):

            imidx_val, labels_val, shapes_val = data_val['imidx'], data_val['label'], data_val['shape']

            if hypar["model_digit"] == "full":
                labels_val = labels_val.type(torch.FloatTensor)
            else:
                labels_val = labels_val.type(torch.HalfTensor)

            # wrap them in Variable
            if torch.cuda.is_available():
                labels_val_v = Variable(labels_val.cuda(), requires_grad=False)
            else:
                labels_val_v = Variable(labels_val, requires_grad=False)

            t_start = time.time()
            ds_val = net(labels_val_v)[0]
            t_end = time.time() - t_start
            tmp_time.append(t_end)

            loss2_val, loss_val = net.compute_loss(ds_val, labels_val_v)

            # compute the F measure
            for t in range(hypar["batch_size_valid"]):
                val_cnt = val_cnt + 1.0
                print("num of val: ", val_cnt)
                i_test = imidx_val[t].data.numpy()

                pred_val = ds_val[0][t, :, :, :]  # B x 1 x H x W

                ## recover the prediction spatial size to the original image size
                pred_val = torch.squeeze(F.interpolate(torch.unsqueeze(pred_val, 0), (shapes_val[t][0], shapes_val[t][1]), mode='bilinear'))

                ma = torch.max(pred_val)
                mi = torch.min(pred_val)
                pred_val = (pred_val - mi) / (ma - mi)  # min-max normalize the prediction to [0, 1]

                gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test]))  # max = 255
                if gt.max() == 1:
                    gt = gt * 255
                with torch.no_grad():
                    gt = torch.tensor(gt).to(device)

                pre, rec, f1, mae = f1_mae_torch(pred_val * 255, gt, valid_dataset, i_test, mybins, hypar)

                PRE[i_test, :] = pre
                REC[i_test, :] = rec
                F1[i_test, :] = f1
                MAE[i_test] = mae

                del ds_val, gt
                gc.collect()
                torch.cuda.empty_cache()

            val_loss += loss_val.item()
            tar_loss += loss2_val.item()

            print("[validating: %5d/%5d] val_ls: %f, tar_ls: %f, f1: %f, mae: %f, time: %f" % (i_val, val_num, val_loss / (i_val + 1), tar_loss / (i_val + 1), np.amax(F1[i_test, :]), MAE[i_test], t_end))

            del loss2_val, loss_val
        print('============================')

        PRE_m = np.mean(PRE, 0)
        REC_m = np.mean(REC, 0)
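        # F-beta measure with beta^2 = 0.3 (the weighting conventionally used in saliency /
        # segmentation benchmarks): F = (1 + beta^2) * P * R / (beta^2 * P + R). PRE_m and
        # REC_m hold the precision and recall averaged over all samples at each of the 255
        # binarization thresholds; the maximum over thresholds is reported as the max F1.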
        f1_m = (1 + 0.3) * PRE_m * REC_m / (0.3 * PRE_m + REC_m + 1e-8)

        tmp_f1.append(np.amax(f1_m))
        tmp_mae.append(np.mean(MAE))
        print("The max F1 Score: %f" % (np.max(f1_m)))
        print("MAE: ", np.mean(MAE))

    return tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time
def train(net, optimizer, train_dataloaders, train_datasets, valid_dataloaders, valid_datasets, hypar, train_dataloaders_val, train_datasets_val):
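    """Main training loop of the segmentation network.

    When hypar["interm_sup"] is True, a frozen GT encoder is obtained first and its feature
    maps serve as extra regression targets through compute_loss_kl; otherwise the standard
    multi-scale loss from compute_loss is used. Every hypar["model_save_fre"] iterations the
    network is validated and a checkpoint is written whenever the max F1 improves on any
    validation dataset; training stops on early stopping, max_ite or max_epoch_num.
    """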
    if hypar["interm_sup"]:
        print("Get the gt encoder ...")
        featurenet = get_gt_encoder(train_dataloaders, train_datasets, valid_dataloaders, valid_datasets, hypar, train_dataloaders_val, train_datasets_val)
        ## freeze the weights of the gt encoder
        for param in featurenet.parameters():
            param.requires_grad = False
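    # With intermediate supervision, the frozen featurenet only produces target feature maps
    # (fs) for the ground-truth masks; the segmentation network's own intermediate features
    # (dfs) are regressed towards them via compute_loss_kl(..., mode='MSE') in the loop below.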
    model_path = hypar["model_path"]
    model_save_fre = hypar["model_save_fre"]
    max_ite = hypar["max_ite"]
    batch_size_train = hypar["batch_size_train"]
    batch_size_valid = hypar["batch_size_valid"]

    if not os.path.exists(model_path):
        os.mkdir(model_path)

    ite_num = hypar["start_ite"]  # count the total iteration number
    ite_num4val = 0
    running_loss = 0.0      # count the total loss
    running_tar_loss = 0.0  # count the target output loss
    last_f1 = [0 for x in range(len(valid_dataloaders))]

    train_num = train_datasets[0].__len__()

    net.train()

    start_last = time.time()
    gos_dataloader = train_dataloaders[0]
    epoch_num = hypar["max_epoch_num"]
    notgood_cnt = 0
    for epoch in range(epoch_num):

        for i, data in enumerate(gos_dataloader):

            if ite_num >= max_ite:
                print("Training Reached the Maximal Iteration Number ", max_ite)
                exit()

            ite_num = ite_num + 1
            ite_num4val = ite_num4val + 1

            # get the inputs
            inputs, labels = data['image'], data['label']

            if hypar["model_digit"] == "full":
                inputs = inputs.type(torch.FloatTensor)
                labels = labels.type(torch.FloatTensor)
            else:
                inputs = inputs.type(torch.HalfTensor)
                labels = labels.type(torch.HalfTensor)

            # wrap them in Variable
            if torch.cuda.is_available():
                inputs_v, labels_v = Variable(inputs.cuda(), requires_grad=False), Variable(labels.cuda(), requires_grad=False)
            else:
                inputs_v, labels_v = Variable(inputs, requires_grad=False), Variable(labels, requires_grad=False)

            # zero the parameter gradients
            start_inf_loss_back = time.time()
            optimizer.zero_grad()

            if hypar["interm_sup"]:
                # forward + backward + optimize with intermediate feature supervision
                ds, dfs = net(inputs_v)
                _, fs = featurenet(labels_v)  ## extract the gt encodings
                loss2, loss = net.compute_loss_kl(ds, labels_v, dfs, fs, mode='MSE')
            else:
                # forward + backward + optimize
                ds, _ = net(inputs_v)
                loss2, loss = net.compute_loss(ds, labels_v)

            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            running_tar_loss += loss2.item()

            del ds, loss2, loss
            end_inf_loss_back = time.time() - start_inf_loss_back

            print(">>>" + model_path.split('/')[-1] + " - [epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f, time-per-iter: %3f s, time_read: %3f" % (
                epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val, time.time() - start_last, time.time() - start_last - end_inf_loss_back))
            start_last = time.time()

            if ite_num % model_save_fre == 0:  # validate every model_save_fre iterations
                notgood_cnt += 1
                net.eval()
                tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time = valid(net, valid_dataloaders, valid_datasets, hypar, epoch)
                net.train()  # resume training mode

                tmp_out = 0
                print("last_f1:", last_f1)
                print("tmp_f1:", tmp_f1)
                for fi in range(len(last_f1)):
                    if tmp_f1[fi] > last_f1[fi]:
                        tmp_out = 1
                print("tmp_out:", tmp_out)

                if tmp_out:
                    notgood_cnt = 0
                    last_f1 = tmp_f1
                    tmp_f1_str = [str(round(f1x, 4)) for f1x in tmp_f1]
                    tmp_mae_str = [str(round(mx, 4)) for mx in tmp_mae]
                    maxf1 = '_'.join(tmp_f1_str)
                    meanM = '_'.join(tmp_mae_str)
                    model_name = "/gpu_itr_" + str(ite_num) + \
                                 "_traLoss_" + str(np.round(running_loss / ite_num4val, 4)) + \
                                 "_traTarLoss_" + str(np.round(running_tar_loss / ite_num4val, 4)) + \
                                 "_valLoss_" + str(np.round(val_loss / (i_val + 1), 4)) + \
                                 "_valTarLoss_" + str(np.round(tar_loss / (i_val + 1), 4)) + \
                                 "_maxF1_" + maxf1 + \
                                 "_mae_" + meanM + \
                                 "_time_" + str(np.round(np.mean(np.array(tmp_time)) / batch_size_valid, 6)) + ".pth"
                    torch.save(net.state_dict(), model_path + model_name)

                running_loss = 0.0
                running_tar_loss = 0.0
                ite_num4val = 0

                if notgood_cnt >= hypar["early_stop"]:
                    print("No improvements in the last " + str(notgood_cnt) + " validation periods, so training stopped !")
                    exit()

    print("Training Reaches The Maximum Epoch Number")
def valid(net, valid_dataloaders, valid_datasets, hypar, epoch=0):
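    """Validate (or run inference with) the segmentation network on the given dataloaders.

    For each sample the first side output of the network is resized to the original image
    size, min-max normalized, and scored against the original ground truth with f1_mae_torch
    (precision/recall over 255 thresholds plus MAE). Returns the same tuple layout as
    valid_gt_encoder.
    """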
    net.eval()
    print("Validating...")
    epoch_num = hypar["max_epoch_num"]

    val_loss = 0.0
    tar_loss = 0.0
    val_cnt = 0.0

    tmp_f1 = []
    tmp_mae = []
    tmp_time = []

    start_valid = time.time()

    for k in range(len(valid_dataloaders)):

        valid_dataloader = valid_dataloaders[k]
        valid_dataset = valid_datasets[k]

        val_num = valid_dataset.__len__()
        mybins = np.arange(0, 256)
        PRE = np.zeros((val_num, len(mybins) - 1))
        REC = np.zeros((val_num, len(mybins) - 1))
        F1 = np.zeros((val_num, len(mybins) - 1))
        MAE = np.zeros(val_num)
        for i_val, data_val in enumerate(valid_dataloader):
            val_cnt = val_cnt + 1.0
            imidx_val, inputs_val, labels_val, shapes_val = data_val['imidx'], data_val['image'], data_val['label'], data_val['shape']

            if hypar["model_digit"] == "full":
                inputs_val = inputs_val.type(torch.FloatTensor)
                labels_val = labels_val.type(torch.FloatTensor)
            else:
                inputs_val = inputs_val.type(torch.HalfTensor)
                labels_val = labels_val.type(torch.HalfTensor)

            # wrap them in Variable
            if torch.cuda.is_available():
                inputs_val_v, labels_val_v = Variable(inputs_val.cuda(), requires_grad=False), Variable(labels_val.cuda(), requires_grad=False)
            else:
                inputs_val_v, labels_val_v = Variable(inputs_val, requires_grad=False), Variable(labels_val, requires_grad=False)

            t_start = time.time()
            ds_val = net(inputs_val_v)[0]
            t_end = time.time() - t_start
            tmp_time.append(t_end)

            loss2_val, loss_val = net.compute_loss(ds_val, labels_val_v)

            # compute the F measure
            for t in range(hypar["batch_size_valid"]):
                i_test = imidx_val[t].data.numpy()

                pred_val = ds_val[0][t, :, :, :]  # B x 1 x H x W

                ## recover the prediction spatial size to the original image size
                pred_val = torch.squeeze(F.interpolate(torch.unsqueeze(pred_val, 0), (shapes_val[t][0], shapes_val[t][1]), mode='bilinear'))

                ma = torch.max(pred_val)
                mi = torch.min(pred_val)
                pred_val = (pred_val - mi) / (ma - mi)  # min-max normalize the prediction to [0, 1]

                if len(valid_dataset.dataset["ori_gt_path"]) != 0:
                    gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test]))  # max = 255
                    if gt.max() == 1:
                        gt = gt * 255
                else:
                    gt = np.zeros((shapes_val[t][0], shapes_val[t][1]))

                with torch.no_grad():
                    gt = torch.tensor(gt).to(device)

                pre, rec, f1, mae = f1_mae_torch(pred_val * 255, gt, valid_dataset, i_test, mybins, hypar)

                PRE[i_test, :] = pre
                REC[i_test, :] = rec
                F1[i_test, :] = f1
                MAE[i_test] = mae

                del ds_val, gt
                gc.collect()
                torch.cuda.empty_cache()

            val_loss += loss_val.item()
            tar_loss += loss2_val.item()

            print("[validating: %5d/%5d] val_ls: %f, tar_ls: %f, f1: %f, mae: %f, time: %f" % (i_val, val_num, val_loss / (i_val + 1), tar_loss / (i_val + 1), np.amax(F1[i_test, :]), MAE[i_test], t_end))

            del loss2_val, loss_val

        print('============================')

        PRE_m = np.mean(PRE, 0)
        REC_m = np.mean(REC, 0)
        f1_m = (1 + 0.3) * PRE_m * REC_m / (0.3 * PRE_m + REC_m + 1e-8)

        tmp_f1.append(np.amax(f1_m))
        tmp_mae.append(np.mean(MAE))

    return tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time
def main(train_datasets,
         valid_datasets,
         hypar):  # hypar["mode"]: "train" or "valid"
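    """Entry point: build the dataloaders, build the model and optimizer, then train or validate.

    hypar["mode"] selects the behaviour: "train" additionally builds the training dataloaders
    and calls train(); any other value only builds the validation dataloaders and calls
    valid() for evaluation / inference.
    """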
    ### --- Step 1: Build datasets and dataloaders ---
    dataloaders_train = []
    dataloaders_valid = []

    if hypar["mode"] == "train":
        print("--- create training dataloader ---")
        ## collect the training dataset
        train_nm_im_gt_list = get_im_gt_name_dict(train_datasets, flag="train")
        ## build the dataloader for the training datasets
        train_dataloaders, train_datasets = create_dataloaders(train_nm_im_gt_list,
                                                               cache_size=hypar["cache_size"],
                                                               cache_boost=hypar["cache_boost_train"],
                                                               my_transforms=[
                                                                   GOSRandomHFlip(),  ## horizontal flip augmentation
                                                                   # GOSResize(hypar["input_size"]),
                                                                   # GOSRandomCrop(hypar["crop_size"]),  ## uncomment for random crop augmentation
                                                                   GOSNormalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0]),
                                                               ],
                                                               batch_size=hypar["batch_size_train"],
                                                               shuffle=True)
        train_dataloaders_val, train_datasets_val = create_dataloaders(train_nm_im_gt_list,
                                                                       cache_size=hypar["cache_size"],
                                                                       cache_boost=hypar["cache_boost_train"],
                                                                       my_transforms=[
                                                                           GOSNormalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0]),
                                                                       ],
                                                                       batch_size=hypar["batch_size_valid"],
                                                                       shuffle=False)
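        # Note: train_dataloaders_val re-reads the same training list without augmentation and
        # without shuffling; it is only used inside get_gt_encoder() to validate the GT encoder.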
        print(len(train_dataloaders), " train dataloaders created")

    print("--- create valid dataloader ---")
    ## collect the validation or testing dataset
    valid_nm_im_gt_list = get_im_gt_name_dict(valid_datasets, flag="valid")
    ## build the dataloader for the validation datasets
    valid_dataloaders, valid_datasets = create_dataloaders(valid_nm_im_gt_list,
                                                           cache_size=hypar["cache_size"],
                                                           cache_boost=hypar["cache_boost_valid"],
                                                           my_transforms=[
                                                               GOSNormalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0]),
                                                               # GOSResize(hypar["input_size"]),
                                                           ],
                                                           batch_size=hypar["batch_size_valid"],
                                                           shuffle=False)
    print(len(valid_dataloaders), " valid dataloaders created")
    ### --- Step 2: Build model and optimizer ---
    print("--- build model ---")
    net = hypar["model"]

    # convert to half precision
    if hypar["model_digit"] == "half":
        net.half()
        for layer in net.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.float()

    if torch.cuda.is_available():
        net.cuda()

    if hypar["restore_model"] != "":
        print("restore model from:")
        print(hypar["model_path"] + "/" + hypar["restore_model"])
        if torch.cuda.is_available():
            net.load_state_dict(torch.load(hypar["model_path"] + "/" + hypar["restore_model"]))
        else:
            net.load_state_dict(torch.load(hypar["model_path"] + "/" + hypar["restore_model"], map_location="cpu"))

    print("--- define optimizer ---")
    optimizer = optim.Adam(net.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    ### --- Step 3: Train or valid the model ---
    if hypar["mode"] == "train":
        train(net,
              optimizer,
              train_dataloaders,
              train_datasets,
              valid_dataloaders,
              valid_datasets,
              hypar,
              train_dataloaders_val, train_datasets_val)
    else:
        valid(net,
              valid_dataloaders,
              valid_datasets,
              hypar)
if __name__ == "__main__":

    ### --------------- STEP 1: Configure the train, valid and test datasets ---------------
    ## configure the train, valid and inference datasets
    train_datasets, valid_datasets = [], []

    dataset_tr = {"name": "DIS5K-TR",
                  "im_dir": "../DIS5K/DIS-TR/im",
                  "gt_dir": "../DIS5K/DIS-TR/gt",
                  "im_ext": ".jpg",
                  "gt_ext": ".png",
                  "cache_dir": "../DIS5K-Cache/DIS-TR"}

    dataset_vd = {"name": "DIS5K-VD",
                  "im_dir": "../DIS5K/DIS-VD/im",
                  "gt_dir": "../DIS5K/DIS-VD/gt",
                  "im_ext": ".jpg",
                  "gt_ext": ".png",
                  "cache_dir": "../DIS5K-Cache/DIS-VD"}

    dataset_te1 = {"name": "DIS5K-TE1",
                   "im_dir": "../DIS5K/DIS-TE1/im",
                   "gt_dir": "../DIS5K/DIS-TE1/gt",
                   "im_ext": ".jpg",
                   "gt_ext": ".png",
                   "cache_dir": "../DIS5K-Cache/DIS-TE1"}

    dataset_te2 = {"name": "DIS5K-TE2",
                   "im_dir": "../DIS5K/DIS-TE2/im",
                   "gt_dir": "../DIS5K/DIS-TE2/gt",
                   "im_ext": ".jpg",
                   "gt_ext": ".png",
                   "cache_dir": "../DIS5K-Cache/DIS-TE2"}

    dataset_te3 = {"name": "DIS5K-TE3",
                   "im_dir": "../DIS5K/DIS-TE3/im",
                   "gt_dir": "../DIS5K/DIS-TE3/gt",
                   "im_ext": ".jpg",
                   "gt_ext": ".png",
                   "cache_dir": "../DIS5K-Cache/DIS-TE3"}

    dataset_te4 = {"name": "DIS5K-TE4",
                   "im_dir": "../DIS5K/DIS-TE4/im",
                   "gt_dir": "../DIS5K/DIS-TE4/gt",
                   "im_ext": ".jpg",
                   "gt_ext": ".png",
                   "cache_dir": "../DIS5K-Cache/DIS-TE4"}
    ### test your own dataset
    dataset_demo = {"name": "your-dataset",
                    "im_dir": "../your-dataset/im",
                    "gt_dir": "",
                    "im_ext": ".jpg",
                    "gt_ext": "",
                    "cache_dir": "../your-dataset/cache"}
    train_datasets = [dataset_tr]  ## multiple dataset dictionaries can be listed here to use several datasets as the training set
    valid_datasets = [dataset_vd]  ## e.g. [dataset_vd, dataset_te1, dataset_te2, dataset_te3, dataset_te4]; set hypar["mode"] = "valid" for inference
    ### --------------- STEP 2: Configure the hyperparameters for training, validation and inference ---------------
    hypar = {}

    ## -- 2.1. configure the model saving or restoring path --
    hypar["mode"] = "train"
    ## "train": for training
    ## "valid": for validation and inference; in "valid" mode the accuracy is calculated and, if
    ## hypar["valid_out_dir"] is not "", the prediction results are also saved into that directory;
    ## otherwise only the accuracy is calculated and no predictions are saved
    hypar["interm_sup"] = False  ## indicate whether to activate the intermediate feature supervision
    if hypar["mode"] == "train":
        hypar["valid_out_dir"] = ""  ## leave it as "" in "train" mode; in "valid" (inference) mode set it according to your local directory
        hypar["model_path"] = "../saved_models/IS-Net-test"  ## model weights saving (or restoring) path
        hypar["restore_model"] = ""  ## name of the segmentation weights (.pth) for resuming training from the last stop or for inference
        hypar["start_ite"] = 0  ## start iteration of the training, can be changed to match a restored training process
        hypar["gt_encoder_model"] = ""
    else:  ## configure the segmentation output path and the to-be-used model weights path
        hypar["valid_out_dir"] = "../your-results/"  ## the inferenced segmentation maps are written into this folder
        hypar["model_path"] = "../saved_models/IS-Net"  ## load the trained weights from this path
        hypar["restore_model"] = "isnet.pth"  ## name of the to-be-loaded weights
    # if hypar["restore_model"] != "":
    #     hypar["start_ite"] = int(hypar["restore_model"].split("_")[2])

    ## -- 2.2. choose the floating point accuracy --
    hypar["model_digit"] = "full"  ## "half" or "full" floating point accuracy
    hypar["seed"] = 0

    ## -- 2.3. cached data spatial size --
    ## To handle large input images, which otherwise take a long time to load during training,
    ## a cache mechanism pre-converts and resizes the jpg and png images into .pt files.
    hypar["cache_size"] = [1024, 1024]  ## cached input spatial resolution, can be configured into a different size
    hypar["cache_boost_train"] = False  ## whether to load all of the cached training data into RAM; True greatly speeds up training but requires more memory
    hypar["cache_boost_valid"] = False  ## whether to load all of the cached validation data into RAM; True greatly speeds up validation but requires more memory

    ## -- 2.4. data augmentation parameters --
    hypar["input_size"] = [1024, 1024]  ## model input spatial size, usually the same as hypar["cache_size"], i.e. the images are not resized further
    hypar["crop_size"] = [1024, 1024]  ## random crop size, usually set smaller than hypar["cache_size"] for augmentation, e.g. [920, 920]
    hypar["random_flip_h"] = 1  ## horizontal flip, currently hard coded in the dataloader and not in use
    hypar["random_flip_v"] = 0  ## vertical flip, currently not in use

    ## -- 2.5. define the model --
    print("building model...")
    hypar["model"] = ISNetDIS()
    hypar["early_stop"] = 20  ## stop training when there is no improvement in the past 20 validation periods; smaller values such as 5 or 10 can also be used
    hypar["model_save_fre"] = 2000  ## validate and save the model weights every 2000 iterations

    hypar["batch_size_train"] = 8  ## batch size for training
    hypar["batch_size_valid"] = 1  ## batch size for validation and inference
    print("batch size: ", hypar["batch_size_train"])

    hypar["max_ite"] = 10000000  ## if early stopping does not end the training, stop it after max_ite iterations
    hypar["max_epoch_num"] = 1000000  ## if early stopping and max_ite do not end the training, stop it after max_epoch_num epochs
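    # Typical usage (assuming this script keeps the repository's file name, e.g.
    # train_valid_inference_main.py):
    #   training : keep hypar["mode"] = "train" and run   python train_valid_inference_main.py
    #   inference: set hypar["mode"] = "valid", point hypar["model_path"]/hypar["restore_model"]
    #              to the trained weights, and run the same script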
    main(train_datasets,
         valid_datasets,
         hypar=hypar)