mirror of
https://github.com/xuebinqin/DIS.git
synced 2024-11-26 08:43:17 +01:00
Add Replicate demo and API
This commit is contained in:
parent
7df2a36362
commit
3c18906552
16
IS-Net/cog.yaml
Normal file
16
IS-Net/cog.yaml
Normal file
@ -0,0 +1,16 @@
|
||||
build:
|
||||
gpu: true
|
||||
python_version: 3.8
|
||||
system_packages:
|
||||
- libgl1-mesa-glx
|
||||
- libglib2.0-0
|
||||
python_packages:
|
||||
- torch==1.9.0
|
||||
- torchvision==0.10.0
|
||||
- numpy==1.21.1
|
||||
- opencv-python==4.5.5.64
|
||||
- matplotlib==3.5.1
|
||||
- tqdm==4.63.1
|
||||
- scikit-image==0.19.2
|
||||
|
||||
predict: "predict.py:Predictor"
|
74
IS-Net/predict.py
Normal file
74
IS-Net/predict.py
Normal file
@ -0,0 +1,74 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.autograd import Variable
|
||||
from torchvision import transforms
|
||||
from torchvision.utils import save_image
|
||||
import torch.nn.functional as F
|
||||
from data_loader_cache import normalize, im_reader, im_preprocess
|
||||
from models.isnet import ISNetDIS
|
||||
|
||||
from cog import BasePredictor, Path, Input
|
||||
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
|
||||
class Predictor(BasePredictor):
|
||||
def setup(self):
|
||||
self.net = ISNetDIS()
|
||||
self.net.load_state_dict(torch.load("isnet.pth", map_location=device))
|
||||
self.net.to(device)
|
||||
self.net.eval()
|
||||
|
||||
def predict(
|
||||
self,
|
||||
input_image: Path = Input(description="Image to segment."),
|
||||
|
||||
) -> Path:
|
||||
cache_size = [1024,1024]
|
||||
image, orig_size = load_image(str(input_image), cache_size)
|
||||
|
||||
image = image.type(torch.FloatTensor)
|
||||
|
||||
image = Variable(image, requires_grad=False).to(device) # wrap inputs in Variable
|
||||
|
||||
ds_val = self.net(image)[0] # list of 6 results
|
||||
|
||||
pred_val = ds_val[0][0, :, :, :] # B x 1 x H x W # we want the first one which is the most accurate prediction
|
||||
|
||||
## recover the prediction spatial size to the orignal image size
|
||||
pred_val = torch.squeeze(F.upsample(torch.unsqueeze(pred_val, 0), (orig_size[0][0], orig_size[0][1]), mode='bilinear'))
|
||||
|
||||
ma = torch.max(pred_val)
|
||||
mi = torch.min(pred_val)
|
||||
pred_val = (pred_val - mi) / (ma - mi) # max = 1
|
||||
|
||||
if device == 'cuda':
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
output_path = "output.png"
|
||||
save_image(pred_val, output_path, normalize=True)
|
||||
|
||||
return Path(output_path)
|
||||
|
||||
|
||||
class GOSNormalize(object):
|
||||
'''
|
||||
Normalize the Image using torch.transforms
|
||||
'''
|
||||
|
||||
def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
|
||||
def __call__(self, image):
|
||||
image = normalize(image, self.mean, self.std)
|
||||
return image
|
||||
|
||||
|
||||
def load_image(im_path, cache_size):
|
||||
im = im_reader(im_path)
|
||||
im, im_shp = im_preprocess(im, cache_size)
|
||||
im = torch.divide(im, 255.0)
|
||||
shape = torch.from_numpy(np.array(im_shp))
|
||||
transform = transforms.Compose([GOSNormalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0])])
|
||||
return transform(im).unsqueeze(0), shape.unsqueeze(0) # make a batch of image, shape
|
@ -12,9 +12,9 @@
|
||||
<br>
|
||||
|
||||
## This is the official repo for our newly formulated DIS task:
|
||||
[**Project Page**](https://xuebinqin.github.io/dis/index.html), [**Arxiv**](https://arxiv.org/pdf/2203.03041.pdf), [**中文**](https://github.com/xuebinqin/xuebinqin.github.io/blob/main/ECCV2022_DIS_Chinese.pdf).
|
||||
|
||||
<br>
|
||||
[[**Project Page**]](https://xuebinqin.github.io/dis/index.html), [[**Arxiv**]](https://arxiv.org/pdf/2203.03041.pdf), [[**中文**]](https://github.com/xuebinqin/xuebinqin.github.io/blob/main/ECCV2022_DIS_Chinese.pdf).
|
||||
[![asd](https://replicate.com/arielreplicate/dichotomous_image_segmentation/badge)](https://replicate.com/arielreplicate/dichotomous_image_segmentation)
|
||||
[![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/doevent/dis-background-removal)
|
||||
|
||||
# PLEASE STAY TUNED FOR OUR DIS V2.0 (Jul. 30th, 2022)
|
||||
![disv2-peacock](figures/peacock.jpg)
|
||||
|
Loading…
Reference in New Issue
Block a user