
object detection models in R? #54

Open
mikeyEcology opened this issue Jul 22, 2021 · 11 comments

Comments

@mikeyEcology

Hi,
I'd like to deploy an object detection model in R and I'm wondering whether there is a function to do this. I see that some image classification models are available, so I'm not sure if I'm missing the detection models or if they're just not available yet. I would like to use a Faster R-CNN analogous to the one implemented in PyTorch. Is this available? If not, do you have any tips for developing it?
Thank you

@dfalbel
Member

dfalbel commented Jul 22, 2021

Hi @mikeyEcology

We still haven't implemented those for R. As far as I know, there's a YOLO implementation here: https://github.com/openvolley/ovml/blob/master/R/yolo.R

If you want to port them, I'd start by implementing the GeneralizedRCNN module:

https://github.com/pytorch/vision/blob/master/torchvision/models/detection/generalized_rcnn.py

and then go to the specialization:

https://github.com/pytorch/vision/blob/master/torchvision/models/detection/faster_rcnn.py

you will also probably need to convert this: https://github.com/pytorch/vision/blob/master/torchvision/models/detection/backbone_utils.py

In general, translating these requires only basic knowledge of Python, enough to convert the loops and if statements. You also shouldn't need to care about scripting, so you can ignore the torch.jit.is_scripting() branches.
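To make the porting order concrete, here is a minimal control-flow sketch of what GeneralizedRCNN.forward does in torchvision: resize and normalize the images, run the backbone, propose regions with the RPN, classify and refine them with the RoI heads, then map the boxes back to the original image sizes. The class name GeneralizedRCNNSketch and the stage callables are hypothetical stand-ins for the real nn.Modules, written as plain Python so the wiring can be followed without torch; it is not the torchvision implementation.

```python
from typing import Any, Callable, Dict, List, Optional, Tuple


class GeneralizedRCNNSketch:
    """Control-flow sketch of torchvision's GeneralizedRCNN.forward.

    Every stage is a plain callable standing in for the real module
    (hypothetical stubs, not torchvision's API).
    """

    def __init__(self, transform: Callable, backbone: Callable, rpn: Callable,
                 roi_heads: Callable, postprocess: Callable) -> None:
        self.transform = transform      # resize + normalize -> (batch, image_sizes, targets)
        self.backbone = backbone        # batch -> feature maps (e.g. ResNet-50 + FPN)
        self.rpn = rpn                  # features -> (proposals, proposal_losses)
        self.roi_heads = roi_heads      # proposals -> (detections, detector_losses)
        self.postprocess = postprocess  # rescale boxes to the original image sizes

    def forward(self, images: List[List[List[List[float]]]],
                targets: Optional[List[Dict[str, Any]]] = None
                ) -> Tuple[Dict[str, float], List[Dict[str, Any]]]:
        # Each image is C x H x W (nested lists here); record (H, W) before resizing.
        original_sizes = [(len(img[0]), len(img[0][0])) for img in images]
        batch, image_sizes, targets = self.transform(images, targets)
        features = self.backbone(batch)
        proposals, proposal_losses = self.rpn(batch, features, targets)
        detections, detector_losses = self.roi_heads(features, proposals, image_sizes, targets)
        detections = self.postprocess(detections, image_sizes, original_sizes)
        losses = {**proposal_losses, **detector_losses}
        # A scripted torchvision model also returns this (losses, detections) pair.
        return losses, detections
```

FasterRCNN in faster_rcnn.py subclasses GeneralizedRCNN and mainly builds a specific transform, RPN and RoI-heads configuration to plug into it, which is why porting the generic module first should pay off.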

@mikeyEcology
Author

Thank you @dfalbel. We'll get working on it. I trained the model in Python, but I'm trying to make it deployable in R because I'm making an R package for ecologists (to evaluate camera trap images). After making the package, we'll be writing this up as a manuscript. If you're interested in participating in writing this code, I'd be happy to include you as a co-author. Please let me know if you're interested: [email protected]

@dfalbel
Member

dfalbel commented Jul 26, 2021

Since you already trained the model in Python, you could script it with torch.jit.script(model), save the result with torch.jit.save(scripted_model, "path/to/file"), and reload it in R with jit_load(). This way the architecture is also serialized, so you wouldn't need to port the code to R.

This requires the development version of torch, though.

@mikeyEcology
Author

Thank you @dfalbel! This looks like a great option. It would be a huge advantage not to have to port the code to R. I have tried this, but I got an error when loading the model into R because it cannot load the nms function. Any suggestions?
I could port the function torchvision.ops.nms from Python to R, but I'm not sure how this would work with JIT loading (this is the first time I've used JIT).

In python, I ran:

m = torch.jit.script(model.to(device='cpu'))
torch.jit.save(m, path2jit)

Then in R I ran:

torch::jit_load(path2jit)

But I get this error:

Error in cpp_jit_load(path) :
Unknown builtin op: torchvision::nms.
Could not find any similar ops to torchvision::nms. This op may not exist or may not be currently supported in TorchScript.
:
File "/home/mtabak/bin/anaconda3/envs/pyt/lib/python3.8/site-packages/torchvision/ops/boxes.py", line 42
"""
_assert_has_ops()
return torch.ops.torchvision.nms(boxes, scores, iou_threshold)
~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
Serialized File "code/torch/torchvision/ops/boxes.py", line 93
_42 = torch.torchvision.extension._assert_has_ops
_43 = _42()
_44 = ops.torchvision.nms(boxes, scores, iou_threshold)
~~~~~~~~~~~~~~~~~~~ <--- HERE
return _44
'nms' is being compiled since it was called from 'batched_nms'
File "/home/mtabak/bin/anaconda3/envs/pyt/lib/python3.8/site-packages/torchvision/ops/boxes.py", line 88
offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes))
boxes_for_nms = boxes + offsets[:, None]
keep = nms(boxes_for
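For reference, the op the error refers to is greedy non-maximum suppression, and the log also shows how batched_nms makes it class-aware: every box is shifted by idxs * (max_coordinate + 1), so boxes from different classes can never overlap. Below is a minimal pure-Python sketch of both, assuming (x1, y1, x2, y2) boxes and plain lists instead of tensors. It mirrors the documented semantics of torchvision.ops.nms (discard any box whose IoU with a higher-scoring kept box exceeds iou_threshold, return kept indices in decreasing score order) but is not its implementation. Note that a port like this would not by itself make jit_load() succeed, because the serialized code calls the registered op ops.torchvision.nms by name.

```python
from typing import List, Sequence, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2)


def iou(a: Box, b: Box) -> float:
    """Intersection over union of two axis-aligned boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def nms(boxes: Sequence[Box], scores: Sequence[float], iou_threshold: float) -> List[int]:
    """Greedy NMS: visit boxes by descending score, keep one unless it
    overlaps an already-kept box with IoU > iou_threshold."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep: List[int] = []
    for i in order:
        if all(iou(boxes[i], boxes[k]) <= iou_threshold for k in keep):
            keep.append(i)
    return keep


def batched_nms(boxes: Sequence[Box], scores: Sequence[float],
                idxs: Sequence[int], iou_threshold: float) -> List[int]:
    """Per-class NMS via the offset trick shown in the log above."""
    if not boxes:
        return []
    max_coordinate = max(max(b) for b in boxes)
    shifted = []
    for b, cls in zip(boxes, idxs):
        off = cls * (max_coordinate + 1)  # separates classes in coordinate space
        shifted.append((b[0] + off, b[1] + off, b[2] + off, b[3] + off))
    return nms(shifted, scores, iou_threshold)
```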

@dfalbel
Member

dfalbel commented Jul 27, 2021

Hmm, this is tricky. It seems that torchvision registers some custom TorchScript operators, and they wouldn't be available at the time we call jit_load() in R. We will need to figure out how to register these custom operators from within R, which will be quite tricky right now.

See eg: https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html#using-the-torchscript-custom-operator-in-c

This is definitely on my todo list for the next few months, but will probably take some time to figure out the best way to do it.

@mikeyEcology
Author

Yes, I can understand why this will be tricky. We'll be working on the same problem on our end, as we need to include this in an R package. Please let me know if you'd like to collaborate on the effort. Clearly, you are much better at this than we are, but our team includes a talented programmer who can also contribute.
If we work together, I'd be more than happy to include you as a co-author on the paper that comes from our R package. These papers tend to get cited a lot (e.g., our previous paper).

@dfalbel
Member

dfalbel commented Jul 27, 2021

Sure, I am happy to collaborate!
FWIW, I was able to load the fasterrcnn model in a not-too-complicated way by doing the following:

  • Clone https://github.com/pytorch/vision
  • mkdir build && cd build (inside the cloned repository)
  • cmake .. -DCMAKE_PREFIX_PATH=../torch/lantern/build/libtorch, where -DCMAKE_PREFIX_PATH is the path to a C++ libtorch binary distribution downloaded from https://pytorch.org/get-started/locally/
  • cmake --build . --parallel 8
  • make DESTDIR=/Users/dfalbel/Documents/libtorchvision install, which creates the ~/Documents/libtorchvision/usr/local/lib/libtorchvision.dylib file.

Now, from R:

library(torch)
dyn.load("~/Documents/libtorchvision/usr/local/lib/libtorchvision.dylib")
model <- jit_load("~/Downloads/fasterrcnn.pt")
inp <- list(torch_rand(3, 300, 400), torch_rand(3, 500, 400))
out <- model(inp)

I am not sure if this would work on Windows, but it sounds like it should just work.
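When scripted, torchvision's detection models return a (losses, detections) pair rather than only the detections, so out in R would presumably be a two-element list whose second element holds one set of detections per input image, each with boxes, labels and scores. A minimal sketch of thresholding that output by score, written in Python with plain lists standing in for tensors; the function name filter_detections and the min_score cutoff are illustrative assumptions, not part of torchvision.

```python
from typing import Any, Dict, List, Tuple

# One entry per input image; each key maps to a per-object list.
Detection = Dict[str, List[Any]]


def filter_detections(output: Tuple[Dict[str, Any], List[Detection]],
                      min_score: float) -> List[Detection]:
    """Keep only objects whose score is at least min_score (hypothetical helper)."""
    _losses, detections = output  # scripted models return (losses, detections)
    filtered: List[Detection] = []
    for det in detections:
        keep = [i for i, s in enumerate(det["scores"]) if s >= min_score]
        filtered.append({k: [det[k][i] for i in keep] for k in ("boxes", "labels", "scores")})
    return filtered
```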

I saved the model from Python using:

import torchvision
import torch

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True, pretrained_backbone=True)
model.eval()
s = torch.jit.script(model)

torch.jit.save(s, "../fasterrcnn.pt")

So, in theory, provided that this works on Windows, what we could do to support running these models from R is:

  • Create infrastructure to build torchvision binaries for every supported OS: probably just a GH Actions workflow that builds torchvision and uploads the binaries to GH releases.
  • The torchvision R package could then, in .onLoad, download those binaries and call dyn.load() to register the TorchScript operators, so users don't have to do that manually.
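The .onLoad step above boils down to "pick the right binary for this OS, download it once into a cache, return its path for loading". A minimal sketch of that logic in Python, assuming a per-platform file name and an injected download callable; the names torchvision_library_name and ensure_library, and the Linux and Windows file names, are assumptions (only the macOS name comes from the build above). On the R side the returned path would be what gets passed to dyn.load().

```python
import os
import sys
from typing import Callable

# Assumed shared-library names per platform; only the .dylib name is taken
# from the macOS build described above.
_LIB_NAMES = {
    "darwin": "libtorchvision.dylib",
    "linux": "libtorchvision.so",
    "win32": "torchvision.dll",
}


def torchvision_library_name(platform: str = sys.platform) -> str:
    """Return the expected library file name for a sys.platform string."""
    for prefix, name in _LIB_NAMES.items():
        if platform.startswith(prefix):
            return name
    raise RuntimeError(f"no prebuilt torchvision binary for platform {platform!r}")


def ensure_library(cache_dir: str, download: Callable[[str], None],
                   platform: str = sys.platform) -> str:
    """Download the binary into cache_dir on first use, then reuse it."""
    path = os.path.join(cache_dir, torchvision_library_name(platform))
    if not os.path.exists(path):
        os.makedirs(cache_dir, exist_ok=True)
        download(path)  # hypothetical: fetch the release asset for this OS to `path`
    return path
```

Injecting download keeps the caching logic testable without network access; a real package would plug in an HTTP fetch of the GH release asset.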

@dfalbel dfalbel reopened this Jul 27, 2021
@mikeyEcology
Author

Thank you. This looks like a great option. I have not been able to build that file yet. Would you be willing to share the libtorchvision.dylib file that you built ([email protected])? I'll keep trying in the meantime.

@mikeyEcology
Author

I'm having a new problem when I try to load a state_dict from Python into R. I'm not having problems with the .pt file that I created above. I'm wondering if this is due to the PyTorch version: I'm currently using 1.9.1 (previously 1.7). I'm using R torch version 0.6.0 and R version 4.1.1.
Here is the code I'm using to save a model in Python:

import torchvision
import torch

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
# ... train model...
model.eval()
s = torch.jit.script(model.to(device = 'cpu'))
torch.jit.save(s, output_path + "weights_R.pt")

Then to load the model in R:

path2weights <- '.../weights_R.pt'
state_dict <- torch::load_state_dict(path2weights)

I get this error:

Error in cpp_load_state_dict(path) :
type_resolver_INTERNAL ASSERT FAILED at "..\..\torch\csrc\jit\serialization\unpickler.cpp":621, please report a bug to PyTorch.
Exception raised from readGlobal at ....\torch\csrc\jit\serialization\unpickler.cpp:621 (most recent call first):
00007FFD37E610D200007FFD37E61070 c10.dll!c10::Error::Error [ @ ]
00007FFD37E60B3E00007FFD37E60AF0 c10.dll!c10::detail::torchCheckFail [ @ ]
00007FFD37E3795900007FFD37E37950 c10.dll!c10::detail::torchInternalAssertFail [ @ ]
00007FFCBF198FA400007FFCBF1982E0 torch_cpu.dll!torch::jit::Unpickler::readGlobal [ @ ]
00007FFCBF19A1F500007FFCBF199240 torch_cpu.dll!torch::jit::Unpickler::readInstruction [ @ ]
00007FFCBF19D11A00007FFCBF19CFE0 torch_cpu.dll!torch::jit::Unpickler::run [ @ ]
00007FFCBF197FEC00007FFCBF197FC0 torch_cpu.dll!t

Any suggestions? Should I try reverting to an older version of pytorch?

@dfalbel
Member

dfalbel commented Oct 20, 2021

I think you want to use jit_load() instead of load_state_dict(), as AFAICT torch.jit.save is saving the entire model in the TorchScript format.
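One way to avoid this mix-up in an R package is to check which kind of file you have before choosing a loader. Both torch.jit.save and (since PyTorch 1.6) torch.save write zip archives, but only TorchScript archives contain a code/ directory with the serialized source; the error log earlier in this thread shows paths such as code/torch/torchvision/ops/boxes.py. The sketch below is a heuristic under that assumption, using Python's zipfile module; is_torchscript_archive is a hypothetical helper, not part of either torch API.

```python
import zipfile
from typing import IO, Union


def is_torchscript_archive(path: Union[str, IO[bytes]]) -> bool:
    """Heuristic: archives written by torch.jit.save contain '<name>/code/...'
    entries holding the serialized source; plain torch.save archives
    (e.g. a state_dict) only hold data.pkl and tensor data."""
    if not zipfile.is_zipfile(path):
        return False  # e.g. the legacy (pre-1.6) torch.save format or another file
    with zipfile.ZipFile(path) as zf:
        return any("/code/" in name for name in zf.namelist())
```

An R package could presumably run the same check with utils::unzip(path, list = TRUE) and dispatch to jit_load() or load_state_dict() accordingly.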

@mikeyEcology
Author

Thank you! My mistake. I was trying to load the architecture as if it was the weights.
