diff --git a/Colab_Demo.ipynb b/Colab_Demo.ipynb
new file mode 100644
index 0000000..7fb6aef
--- /dev/null
+++ b/Colab_Demo.ipynb
@@ -0,0 +1,391 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Clone official repo"
+ ],
+ "metadata": {
+ "id": "P1rhi9xgJR-x"
+ },
+ "id": "P1rhi9xgJR-x"
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!git clone https://github.com/xuebinqin/DIS\n",
+ "# Work from IS-Net: the project imports and ./saved_models used below are relative to it.\n",
+ "%cd ./DIS/IS-Net\n",
+ "# %pip (rather than !pip) targets the environment of the running kernel.\n",
+ "%pip install gdown"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "wlRB0Pq0JIvF",
+ "outputId": "8b8e5619-4c39-46b6-8e3c-520c4b68cb23"
+ },
+ "id": "wlRB0Pq0JIvF",
+ "execution_count": 1,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Cloning into 'DIS'...\n",
+ "remote: Enumerating objects: 151, done.\u001b[K\n",
+ "remote: Counting objects: 100% (116/116), done.\u001b[K\n",
+ "remote: Compressing objects: 100% (83/83), done.\u001b[K\n",
+ "remote: Total 151 (delta 35), reused 108 (delta 30), pack-reused 35\u001b[K\n",
+ "Receiving objects: 100% (151/151), 43.23 MiB | 34.50 MiB/s, done.\n",
+ "Resolving deltas: 100% (37/37), done.\n",
+ "/content/DIS/IS-Net\n",
+ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
+ "Requirement already satisfied: gdown in /usr/local/lib/python3.7/dist-packages (4.4.0)\n",
+ "Requirement already satisfied: beautifulsoup4 in /usr/local/lib/python3.7/dist-packages (from gdown) (4.6.3)\n",
+ "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from gdown) (1.15.0)\n",
+ "Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from gdown) (4.64.0)\n",
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from gdown) (3.7.1)\n",
+ "Requirement already satisfied: requests[socks] in /usr/local/lib/python3.7/dist-packages (from gdown) (2.23.0)\n",
+ "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests[socks]->gdown) (3.0.4)\n",
+ "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests[socks]->gdown) (1.24.3)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests[socks]->gdown) (2022.6.15)\n",
+ "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests[socks]->gdown) (2.10)\n",
+ "Requirement already satisfied: PySocks!=1.5.7,>=1.5.6 in /usr/local/lib/python3.7/dist-packages (from requests[socks]->gdown) (1.7.1)\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Imports"
+ ],
+ "metadata": {
+ "id": "RO0DY6O3Jqe9"
+ },
+ "id": "RO0DY6O3Jqe9"
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import numpy as np\n",
+ "from PIL import Image\n",
+ "import torch\n",
+ "from torch.autograd import Variable\n",
+ "from torchvision import transforms\n",
+ "import torch.nn.functional as F\n",
+ "import gdown\n",
+ "import os\n",
+ "\n",
+ "import requests\n",
+ "import matplotlib.pyplot as plt\n",
+ "from io import BytesIO\n",
+ "\n",
+ "# project imports\n",
+ "from data_loader_cache import normalize, im_reader, im_preprocess \n",
+ "from models import *\n"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "9fFNd2X_Js0e",
+ "outputId": "80e92881-6893-4227-a7a7-4f39a566061f"
+ },
+ "id": "9fFNd2X_Js0e",
+ "execution_count": 2,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "/usr/local/lib/python3.7/dist-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='mean' instead.\n",
+ " warnings.warn(warning.format(ret))\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Helpers"
+ ],
+ "metadata": {
+ "id": "h1C9zSdkJgtF"
+ },
+ "id": "h1C9zSdkJgtF"
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
+ "# Fetch the official weights whenever the checkpoint file is absent; testing the file rather\n",
+ "# than only the folder also retries when ./saved_models exists but holds no weights.\n",
+ "MODEL_PATH_URL = \"https://drive.google.com/uc?id=1KyMpRjewZdyYfxHPYcd-ZbanIXtin0Sn\"\n",
+ "if not os.path.exists(\"./saved_models/isnet.pth\"):\n",
+ "    os.makedirs(\"./saved_models\", exist_ok=True)  # pure-Python replacement for !mkdir\n",
+ "    gdown.download(MODEL_PATH_URL, \"./saved_models/isnet.pth\", use_cookies=False)\n",
+ "\n",
+ "\n",
+ "class GOSNormalize:\n",
+ "    '''\n",
+ "    Callable transform that applies the project's `normalize` helper with a fixed mean and std.\n",
+ "    '''\n",
+ "    def __init__(self, mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]):\n",
+ "        # Keep the statistics exactly as passed; they are forwarded on every call.\n",
+ "        self.mean, self.std = mean, std\n",
+ "\n",
+ "    def __call__(self,image):\n",
+ "        # Delegate to the imported normalize() and hand its result straight back.\n",
+ "        return normalize(image, self.mean, self.std)\n",
+ "\n",
+ "# Shared preprocessing step: GOSNormalize with mean 0.5 and std 1.0 (three entries each).\n",
+ "transform = transforms.Compose([GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0])])\n",
+ "\n",
+ "def load_image(im_path, hypar):\n",
+ "    '''Read an image from a local path or http(s) URL; return (normalized 1-image batch, batched shape from im_preprocess).'''\n",
+ "    if im_path.startswith(\"http\"):\n",
+ "        response = requests.get(im_path, timeout=30)  # never hang forever on a stalled server\n",
+ "        response.raise_for_status()  # fail loudly rather than decode an HTTP error page as an image\n",
+ "        im_path = BytesIO(response.content)\n",
+ "    im, im_shp = im_preprocess(im_reader(im_path), hypar[\"cache_size\"])\n",
+ "    shape = torch.from_numpy(np.array(im_shp))\n",
+ "    return transform(torch.divide(im,255.0)).unsqueeze(0), shape.unsqueeze(0) # make a batch of image, shape\n",
+ "\n",
+ "\n",
+ "def build_model(hypar,device):\n",
+ "    '''Move hypar[\"model\"] to `device`, optionally restore its weights, and return it in eval mode.\n",
+ "    hypar[\"model_digit\"] == \"half\" casts the net to fp16 while keeping BatchNorm2d layers in fp32;\n",
+ "    a non-empty hypar[\"restore_model\"] is loaded from the hypar[\"model_path\"] folder.'''\n",
+ "    net = hypar[\"model\"]\n",
+ "    if hypar[\"model_digit\"] == \"half\":\n",
+ "        net.half()\n",
+ "        # Spell out torch.nn: this notebook never imports a bare `nn` name itself.\n",
+ "        for layer in net.modules():\n",
+ "            if isinstance(layer, torch.nn.BatchNorm2d):\n",
+ "                layer.float()\n",
+ "    net.to(device)\n",
+ "    if hypar[\"restore_model\"] != \"\":\n",
+ "        # map_location remaps the checkpoint tensors onto the selected device.\n",
+ "        net.load_state_dict(torch.load(hypar[\"model_path\"]+\"/\"+hypar[\"restore_model\"], map_location=device))\n",
+ "    net.eval()\n",
+ "    return net\n",
+ "\n",
+ " \n",
+ "def predict(net, inputs_val, shapes_val, hypar, device):\n",
+ "    '''\n",
+ "    Given an image batch, predict its mask and return it as a uint8 array scaled to 0..255.\n",
+ "\n",
+ "    inputs_val: image tensor batch (as built by load_image); only its first prediction is used.\n",
+ "    shapes_val: shapes_val[0][0] and shapes_val[0][1] give the output height and width.\n",
+ "    hypar[\"model_digit\"]: \"full\" runs in float32, any other value in float16.\n",
+ "    '''\n",
+ "    net.eval()\n",
+ "\n",
+ "    inputs_val = inputs_val.float() if hypar[\"model_digit\"]==\"full\" else inputs_val.half()\n",
+ "    out_size = (int(shapes_val[0][0]), int(shapes_val[0][1]))\n",
+ "\n",
+ "    # Inference only: no_grad skips autograd bookkeeping, so no Variable wrapper is needed.\n",
+ "    with torch.no_grad():\n",
+ "        ds_val = net(inputs_val.to(device))[0] # list of 6 results\n",
+ "        pred_val = ds_val[0][0,:,:,:] # first result is the most accurate prediction; keep batch item 0\n",
+ "        ## recover the prediction spatial size to the original image size\n",
+ "        # (F.interpolate supersedes the deprecated F.upsample; align_corners=False was its default)\n",
+ "        pred_val = torch.squeeze(F.interpolate(torch.unsqueeze(pred_val,0), out_size, mode='bilinear', align_corners=False))\n",
+ "\n",
+ "    # Min-max scale to [0, 1]; a constant prediction would divide by zero, so it maps to all zeros.\n",
+ "    ma, mi = torch.max(pred_val), torch.min(pred_val)\n",
+ "    pred_val = (pred_val-mi)/(ma-mi) if ma > mi else torch.zeros_like(pred_val)\n",
+ "\n",
+ "    if device == 'cuda': torch.cuda.empty_cache()\n",
+ "    return (pred_val.detach().cpu().numpy()*255).astype(np.uint8) # it is the mask we need"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "BFVvxhZQJkEy",
+ "outputId": "e8271029-16de-4501-a579-c04bd16aa420"
+ },
+ "id": "BFVvxhZQJkEy",
+ "execution_count": 3,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "Downloading...\n",
+ "From: https://drive.google.com/uc?id=1KyMpRjewZdyYfxHPYcd-ZbanIXtin0Sn\n",
+ "To: /content/DIS/IS-Net/saved_models/isnet.pth\n",
+ "100%|██████████| 177M/177M [00:03<00:00, 56.6MB/s]\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Set Parameters"
+ ],
+ "metadata": {
+ "id": "H7OQxVqaOgtk"
+ },
+ "id": "H7OQxVqaOgtk"
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "189b719a-c9a2-4048-8620-0501fd5653ec",
+ "metadata": {
+ "id": "189b719a-c9a2-4048-8620-0501fd5653ec"
+ },
+ "outputs": [],
+ "source": [
+ "# Parameters for inference, gathered in one dict literal.\n",
+ "hypar = {\n",
+ "    \"model_path\": \"./saved_models\",   ## load trained weights from this path\n",
+ "    \"restore_model\": \"isnet.pth\",      ## name of the to-be-loaded weights\n",
+ "    \"interm_sup\": False,               ## whether intermediate feature supervision is activated\n",
+ "\n",
+ "    ## choose floating point accuracy --\n",
+ "    \"model_digit\": \"full\",             ## \"half\" or \"full\" precision of float numbers\n",
+ "    \"seed\": 0,\n",
+ "\n",
+ "    \"cache_size\": [1024, 1024],        ## cached input spatial resolution, can be configured to a different size\n",
+ "\n",
+ "    ## data augmentation parameters ---\n",
+ "    \"input_size\": [1024, 1024],        ## model input spatial size, usually the same as \"cache_size\", which means the images are not resized further\n",
+ "    \"crop_size\": [1024, 1024],         ## random crop size from the input, usually smaller than \"cache_size\", e.g., [920,920] for data augmentation\n",
+ "\n",
+ "    \"model\": ISNetDIS(),               ## network instance consumed by build_model\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0af5269e-26a6-4370-8863-92b7381ee90f",
+ "metadata": {
+ "tags": [],
+ "id": "0af5269e-26a6-4370-8863-92b7381ee90f"
+ },
+ "source": [
+ "# Build Model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "b23ea487-1f64-4443-95b4-7998b5345310",
+ "metadata": {
+ "id": "b23ea487-1f64-4443-95b4-7998b5345310"
+ },
+ "outputs": [],
+ "source": [
+ "net = build_model(hypar, device)  # restores the weights named in hypar; net comes back in eval mode"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8beb1f62-0345-4c82-a2e3-9a4db55a55a2",
+ "metadata": {
+ "id": "8beb1f62-0345-4c82-a2e3-9a4db55a55a2"
+ },
+ "source": [
+ "# Predict Mask"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "036b21e8-556b-43dd-b9fb-1ea085f7f0f1",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 753
+ },
+ "id": "036b21e8-556b-43dd-b9fb-1ea085f7f0f1",
+ "outputId": "1ce15be9-9287-498a-82ec-637b4473abb2"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py:3722: UserWarning: nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.\n",
+ " warnings.warn(\"nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.\")\n",
+ "/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py:1960: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.\n",
+ " warnings.warn(\"nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.\")\n"
+ ]
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "