Add Replicate demo and API

This commit is contained in:
ariel 2022-10-28 13:13:24 +03:00
parent 7df2a36362
commit 3c18906552
3 changed files with 93 additions and 3 deletions

16
IS-Net/cog.yaml Normal file
View File

@ -0,0 +1,16 @@
# Cog configuration for the IS-Net demo: describes the build environment and
# names the predictor entry point.
# NOTE(review): nesting indentation appears to have been stripped in this view;
# the keys below `build:` are presumably its children -- confirm against the repo.
build:
# Build the image with GPU support.
gpu: true
python_version: 3.8
# OS-level libraries; presumably the shared objects opencv-python needs at
# import time -- TODO confirm.
system_packages:
- libgl1-mesa-glx
- libglib2.0-0
# Python dependencies, pinned to exact versions.
python_packages:
- torch==1.9.0
- torchvision==0.10.0
- numpy==1.21.1
- opencv-python==4.5.5.64
- matplotlib==3.5.1
- tqdm==4.63.1
- scikit-image==0.19.2
# Entry point: the `Predictor` class defined in predict.py.
predict: "predict.py:Predictor"

74
IS-Net/predict.py Normal file
View File

@ -0,0 +1,74 @@
import numpy as np
import torch
from torch.autograd import Variable
from torchvision import transforms
from torchvision.utils import save_image
import torch.nn.functional as F
from data_loader_cache import normalize, im_reader, im_preprocess
from models.isnet import ISNetDIS
from cog import BasePredictor, Path, Input
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
class Predictor(BasePredictor):
    """Cog predictor that turns an input image into an IS-Net segmentation mask."""

    def setup(self):
        """Create the network and load its weights once, before predictions run."""
        self.net = ISNetDIS()
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        self.net.load_state_dict(torch.load("isnet.pth", map_location=device))
        self.net.to(device)
        # Inference only: put the module in evaluation mode.
        self.net.eval()

    def predict(
        self,
        input_image: Path = Input(description="Image to segment."),
    ) -> Path:
        """Segment ``input_image`` and return the path of the saved mask.

        The image is loaded through ``load_image`` with a 1024x1024 cache
        size, passed through the network, resized back to the original image
        size, min-max scaled and written out as a PNG.

        Args:
            input_image: Path of the image to segment.

        Returns:
            Path of the PNG file holding the predicted mask.
        """
        # Local import keeps this block self-contained.
        import tempfile

        cache_size = [1024, 1024]
        image, orig_size = load_image(str(input_image), cache_size)
        # `Variable` is a deprecated no-op wrapper, so the tensor is used directly.
        image = image.float().to(device)
        # Plain ints for the target size instead of 0-d tensors.
        height, width = int(orig_size[0][0]), int(orig_size[0][1])

        # No gradients are needed for inference; skipping the autograd graph
        # saves memory and time.
        with torch.no_grad():
            # First element of the network output (the side predictions,
            # per the original author's note).
            ds_val = self.net(image)[0]
            # First side output, first batch item; the original author notes
            # it is the most accurate prediction.
            pred_val = ds_val[0][0, :, :, :]
            # Recover the prediction's spatial size to the original image size.
            # F.interpolate replaces the deprecated F.upsample;
            # align_corners=False is the behaviour the old call defaulted to.
            pred_val = torch.squeeze(
                F.interpolate(
                    torch.unsqueeze(pred_val, 0),
                    (height, width),
                    mode='bilinear',
                    align_corners=False,
                )
            )
            ma = torch.max(pred_val)
            mi = torch.min(pred_val)
            value_range = ma - mi
            if value_range > 0:
                pred_val = (pred_val - mi) / value_range  # max = 1
            else:
                # A constant prediction would divide by zero and yield NaNs.
                pred_val = torch.zeros_like(pred_val)

        if device == 'cuda':
            torch.cuda.empty_cache()

        # Write into a fresh temporary directory so the result does not
        # overwrite a fixed "output.png" in the working directory.
        output_path = Path(tempfile.mkdtemp()) / "output.png"
        save_image(pred_val, str(output_path), normalize=True)
        return output_path
class GOSNormalize(object):
    """Callable transform that normalizes an image with a fixed mean and std.

    Args:
        mean: Per-channel means handed to ``normalize``.
        std: Per-channel standard deviations handed to ``normalize``.
    """

    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        # Store per-instance copies: the defaults used to be shared list
        # objects, so mutating one instance's `mean`/`std` silently changed
        # the defaults of every later instance.
        self.mean = list(mean)
        self.std = list(std)

    def __call__(self, image):
        """Return ``image`` normalized with this instance's mean and std."""
        return normalize(image, self.mean, self.std)
def load_image(im_path, cache_size):
    """Read an image and prepare it as a one-item batch for the network.

    Args:
        im_path: Path of the image file, passed to ``im_reader``.
        cache_size: Target size passed to ``im_preprocess``.

    Returns:
        A ``(batch, shape)`` pair: the preprocessed image with a leading
        batch dimension, and the shape reported by ``im_preprocess`` (used by
        the caller as the original image size) with a leading batch
        dimension as well.
    """
    raw = im_reader(im_path)
    resized, reported_shape = im_preprocess(raw, cache_size)

    # Scale by 1/255, then shift by a mean of 0.5 with unit std.
    scaled = torch.divide(resized, 255.0)
    normalizer = GOSNormalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    batch = torch.unsqueeze(normalizer(scaled), 0)

    shape_batch = torch.unsqueeze(torch.from_numpy(np.array(reported_shape)), 0)
    return batch, shape_batch

View File

@ -12,9 +12,9 @@
<br> <br>
## This is the official repo for our newly formulated DIS task: ## This is the official repo for our newly formulated DIS task:
[**Project Page**](https://xuebinqin.github.io/dis/index.html), [**Arxiv**](https://arxiv.org/pdf/2203.03041.pdf), [**中文**](https://github.com/xuebinqin/xuebinqin.github.io/blob/main/ECCV2022_DIS_Chinese.pdf). [[**Project Page**]](https://xuebinqin.github.io/dis/index.html), [[**Arxiv**]](https://arxiv.org/pdf/2203.03041.pdf), [[**中文**]](https://github.com/xuebinqin/xuebinqin.github.io/blob/main/ECCV2022_DIS_Chinese.pdf).
[![Replicate](https://replicate.com/arielreplicate/dichotomous_image_segmentation/badge)](https://replicate.com/arielreplicate/dichotomous_image_segmentation)
<br> [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/doevent/dis-background-removal)
# PLEASE STAY TUNED FOR OUR DIS V2.0 (Jul. 30th, 2022) # PLEASE STAY TUNED FOR OUR DIS V2.0 (Jul. 30th, 2022)
![disv2-peacock](figures/peacock.jpg) ![disv2-peacock](figures/peacock.jpg)