mirror of
https://github.com/xuebinqin/DIS.git
synced 2024-11-26 16:53:17 +01:00
Add Replicate demo and API
This commit is contained in:
parent
7df2a36362
commit
3c18906552
16
IS-Net/cog.yaml
Normal file
16
IS-Net/cog.yaml
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
# Cog configuration for the IS-Net demo.
build:
  gpu: true
  # Quoted so YAML keeps it a string instead of parsing it as the float 3.8.
  python_version: "3.8"
  system_packages:
    - libgl1-mesa-glx
    - libglib2.0-0
  python_packages:
    - torch==1.9.0
    - torchvision==0.10.0
    - numpy==1.21.1
    - opencv-python==4.5.5.64
    - matplotlib==3.5.1
    - tqdm==4.63.1
    - scikit-image==0.19.2

# Entry point, written as "<file>:<class>".
predict: "predict.py:Predictor"
|
74
IS-Net/predict.py
Normal file
74
IS-Net/predict.py
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch.autograd import Variable
|
||||||
|
from torchvision import transforms
|
||||||
|
from torchvision.utils import save_image
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from data_loader_cache import normalize, im_reader, im_preprocess
|
||||||
|
from models.isnet import ISNetDIS
|
||||||
|
|
||||||
|
from cog import BasePredictor, Path, Input
|
||||||
|
|
||||||
|
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
|
|
||||||
|
|
||||||
|
class Predictor(BasePredictor):
    """Cog predictor that segments one image with ISNetDIS and returns the mask."""

    def setup(self):
        """Build the network, load the ``isnet.pth`` weights and enter eval mode."""
        self.net = ISNetDIS()
        self.net.load_state_dict(torch.load("isnet.pth", map_location=device))
        self.net.to(device)
        self.net.eval()

    def predict(
        self,
        input_image: Path = Input(description="Image to segment."),
    ) -> Path:
        """Segment ``input_image`` and write the mask to ``output.png``.

        The image is loaded at a fixed 1024x1024 working size, passed through
        the network, resized back to the size reported by ``load_image`` and
        min-max rescaled before being saved.

        Returns:
            Path to the saved mask image.
        """
        cache_size = [1024, 1024]
        image, orig_size = load_image(str(input_image), cache_size)

        image = image.float().to(device)

        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            ds_val = self.net(image)[0]  # list of 6 results

            # B x 1 x H x W; we want the first one, which is the most
            # accurate prediction.
            pred_val = ds_val[0][0, :, :, :]

            # Recover the prediction's spatial size to the original image size.
            # Plain ints are passed instead of 0-d tensors.
            # NOTE(review): orig_size is presumably (height, width) as
            # reported by im_preprocess -- confirm against data_loader_cache.
            target_size = (int(orig_size[0][0]), int(orig_size[0][1]))
            pred_val = torch.squeeze(
                F.interpolate(
                    torch.unsqueeze(pred_val, 0),
                    size=target_size,
                    mode='bilinear',
                    align_corners=False,
                )
            )

            # Min-max rescale. A constant prediction has zero range and would
            # otherwise produce NaNs from 0 / 0, so it maps to zeros instead.
            ma = torch.max(pred_val)
            mi = torch.min(pred_val)
            value_range = ma - mi
            if value_range > 0:
                pred_val = (pred_val - mi) / value_range  # max = 1
            else:
                pred_val = torch.zeros_like(pred_val)

        if device == 'cuda':
            torch.cuda.empty_cache()

        output_path = "output.png"
        save_image(pred_val, output_path, normalize=True)

        return Path(output_path)
|
||||||
|
|
||||||
|
|
||||||
|
class GOSNormalize(object):
    '''
    Callable transform that applies ``normalize`` with a fixed mean and std.

    ``mean`` and ``std`` are per-channel sequences. Each instance stores its
    own list copy, so changing one instance's values cannot leak into the
    defaults or into the caller's sequences.
    '''

    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        # Copy into fresh lists: the attributes stay lists, as before, but are
        # no longer shared between instances.
        self.mean = list(mean)
        self.std = list(std)

    def __call__(self, image):
        '''Return ``image`` normalized with this instance's mean and std.'''
        image = normalize(image, self.mean, self.std)
        return image
|
||||||
|
|
||||||
|
|
||||||
|
def load_image(im_path, cache_size):
    """Read an image file and return it as a normalized one-image batch.

    Args:
        im_path: Path of the image, passed to ``im_reader``.
        cache_size: Target size passed to ``im_preprocess``.

    Returns:
        A ``(image_batch, shape_batch)`` tuple. Both tensors have a leading
        batch dimension of 1; ``shape_batch`` holds the shape reported by
        ``im_preprocess``.
    """
    raw = im_reader(im_path)
    tensor, reported_shape = im_preprocess(raw, cache_size)

    # Divide by 255 (presumably mapping 8-bit pixel values into [0, 1]).
    scaled = torch.divide(tensor, 255.0)

    # Normalize with mean 0.5 and std 1.0 on each of the three channels.
    pipeline = transforms.Compose([GOSNormalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0])])
    image_batch = pipeline(scaled).unsqueeze(0)

    shape_batch = torch.from_numpy(np.array(reported_shape)).unsqueeze(0)
    return image_batch, shape_batch  # make a batch of image, shape
|
@ -12,9 +12,9 @@
|
|||||||
<br>
|
<br>
|
||||||
|
|
||||||
## This is the official repo for our newly formulated DIS task:
|
## This is the official repo for our newly formulated DIS task:
|
||||||
[**Project Page**](https://xuebinqin.github.io/dis/index.html), [**Arxiv**](https://arxiv.org/pdf/2203.03041.pdf), [**中文**](https://github.com/xuebinqin/xuebinqin.github.io/blob/main/ECCV2022_DIS_Chinese.pdf).
|
[[**Project Page**]](https://xuebinqin.github.io/dis/index.html), [[**Arxiv**]](https://arxiv.org/pdf/2203.03041.pdf), [[**中文**]](https://github.com/xuebinqin/xuebinqin.github.io/blob/main/ECCV2022_DIS_Chinese.pdf).
|
||||||
|
[![Replicate](https://replicate.com/arielreplicate/dichotomous_image_segmentation/badge)](https://replicate.com/arielreplicate/dichotomous_image_segmentation)
|
||||||
<br>
|
[![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/doevent/dis-background-removal)
|
||||||
|
|
||||||
# PLEASE STAY TUNED FOR OUR DIS V2.0 (Jul. 30th, 2022)
|
# PLEASE STAY TUNED FOR OUR DIS V2.0 (Jul. 30th, 2022)
|
||||||
![disv2-peacock](figures/peacock.jpg)
|
![disv2-peacock](figures/peacock.jpg)
|
||||||
|
Loading…
Reference in New Issue
Block a user