Skip to content

Simple formatting with Black, CPU support for inference and forgotten main function in training script#3

Open
aquadzn wants to merge 5 commits into
xuebinqin:masterfrom
aquadzn:master
Open

Simple formatting with Black, CPU support for inference and forgotten main function in training script#3
aquadzn wants to merge 5 commits into
xuebinqin:masterfrom
aquadzn:master

Conversation

@aquadzn
Copy link
Copy Markdown

@aquadzn aquadzn commented May 10, 2020

Hello, thank you for uploading your code.

I've made a few small changes.

Comment thread .gitignore
Comment on lines +1 to +5
__pycache__
*/__pycache__
**/__pycache__
saved_models/
.vscode No newline at end of file
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Simple .gitignore

Comment thread u2net_test.py
Comment on lines +2 to +17
import glob
import time

import numpy as np
from PIL import Image
from skimage import io, transform

import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms#, utils
# import torch.optim as optim

import numpy as np
from PIL import Image
import glob
# import torch.optim as optim
import torchvision
from torchvision import transforms # , utils
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

prettier import statements

Comment thread u2net_test.py
Comment on lines 91 to +96
if torch.cuda.is_available():
net.load_state_dict(torch.load(model_dir))
net.cuda()
else:
net.load_state_dict(torch.load(model_dir, map_location=torch.device("cpu")))

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you try to load_state_dict on CPU without mapping location to CPU, you will have a RuntimeError
RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

Copy link
Copy Markdown

@MatiasConTilde MatiasConTilde May 10, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have tried this with torch==0.4.0 and it gives me TypeError: 'torch.Device' object is not callable. Solved it by upgrading to torch==0.4.1, so this should also be updated in the README

Copy link
Copy Markdown
Author

@aquadzn aquadzn May 10, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@MatiasConTilde this should also work on torch>=0.4 with map_location="cpu"

941b8dd

Comment thread u2net_test.py
Comment on lines +102 to +116
start = time.time()

inputs_test = data_test['image']
inputs_test = data_test["image"]
inputs_test = inputs_test.type(torch.FloatTensor)

if torch.cuda.is_available():
inputs_test = Variable(inputs_test.cuda())
else:
inputs_test = Variable(inputs_test)

d1,d2,d3,d4,d5,d6,d7= net(inputs_test)
d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

print(
f"Predicted {os.path.basename(img_name_list[i_test])} in {time.time() - start:.2f}s"
)
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image

Comment thread u2net_train.py
return loss0, loss


def main():
Copy link
Copy Markdown
Author

@aquadzn aquadzn May 10, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that you forgot to add the main function

Copy link
Copy Markdown
Author

@aquadzn aquadzn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

except for the commented lines, the rest is just reformatting with black.

@aquadzn aquadzn changed the title Simple formatting with Black, CPU support for inference and forgotten main function in training file Simple formatting with Black, CPU support for inference and forgotten main function in training script May 10, 2020
should work with torch>=0.4
@aquadzn
Copy link
Copy Markdown
Author

aquadzn commented May 12, 2020

@Nathanua ?

@xuebinqin
Copy link
Copy Markdown
Owner

thanks for your contribution we are reviewing and testing it. Will update later.

@aquadzn
Copy link
Copy Markdown
Author

aquadzn commented May 14, 2020

Also, adding with torch.no_grad(): before enumerate(*_loader) and removing Variable for torch.Tensor reduces memory usage by a few hundreds of MBs

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants