Skip to content

Commit

Permalink
Merge branch 'main' of github.com:spfrommer/torchexplorer
Browse files Browse the repository at this point in the history
  • Loading branch information
spfrommer committed Dec 30, 2023
2 parents 3b0b683 + 95f270f commit 5e42e15
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 7 deletions.
17 changes: 16 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,26 @@ torchexplorer.watch(model, backend='wandb') # Or 'standalone'
### Install
Installing requires one external `graphviz` dependency, which should be available on most package managers.

_Linux_.
```bash
sudo apt-get install libgraphviz-dev graphviz
pip install torchexplorer
```
For Mac, `brew install graphviz` should suffice. If the `pygraphviz` wheel build fails because it can't find `Python.h`, you must install the python header files as described [here](https://stackoverflow.com/a/22077790/4864247).
If the `pygraphviz` wheel build fails because it can't find `Python.h`, you must install the python header files as described [here](https://stackoverflow.com/a/22077790/4864247).

_Mac_.
```bash
brew install graphviz
pip install torchexplorer
```
If there's an error regarding `#include "graphviz/cgraph.h"`, [the following](https://github.com/pygraphviz/pygraphviz/issues/11#issuecomment-1038479834) worked for me on Apple silicon:
```bash
python -m pip install \
--global-option=build_ext \
--global-option="-I$(brew --prefix graphviz)/include/" \
--global-option="-L$(brew --prefix graphviz)/lib/" \
pygraphviz
```

### Usage

Expand Down
4 changes: 2 additions & 2 deletions torchexplorer/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ def watch(
log_freq: int = 500,
ignore_io_grad_classes: list[type] = [],
disable_inplace: bool = False,
bins: int = 20,
sample_n: Optional[int] = 100,
bins: int = 30,
sample_n: Optional[int] = 1000,
reject_outlier_proportion: float = 0.1,
time_log: tuple[str, Callable] = ('step', lambda module, step: step),
delay_log_multi_backward: bool = False,
Expand Down
4 changes: 2 additions & 2 deletions torchexplorer/vega/mock_vega_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def main():
model = torchvision.models.resnet18()
inplace_classes = [torchvision.models.resnet.BasicBlock]
structure_wrapper = api.watch(
model, log_freq=1, backend='none',
model, log_freq=1, backend='none', bins=20, sample_n=500,
ignore_io_grad_classes=inplace_classes, disable_inplace=True
)
X, y = torch.randn(5, 3, 32, 32), torch.randn(5, 1000)
Expand All @@ -65,7 +65,7 @@ def main():
optimizer = optim.Adam(model.parameters(), lr=1e-2)
loss_fn = torch.nn.MSELoss()

for step in range(2):
for step in range(10):
y_hat = model(X)
loss = loss_fn(y_hat, y)
loss.backward()
Expand Down
11 changes: 9 additions & 2 deletions torchexplorer/vega/vega_dataless.json
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

{
"$schema": "https://vega.github.io/schema/vega/v5.json",
"description": "A neural network visualizer.",
Expand Down Expand Up @@ -1448,10 +1449,15 @@
}
},
"transform": [
{
"type": "formula",
"as": "x",
"expr": "datum.x+1"
},
{
"type": "formula",
"as": "width",
"expr": "datum.width+1"
"expr": "datum.width-2"
},
{
"type": "formula",
Expand Down Expand Up @@ -1556,7 +1562,8 @@
"enter": {
"stroke": {"value": "#000"},
"strokeWidth": {"value": 2.5},
"strokeDash": {"value": [8,2]}
"strokeDash": {"value": [7,4]},
"strokeOpacity": {"value": 1.0}
},
"update": {
"x": {"signal": "datum.histogram_x + histogram_margin + histogram_inner_width"},
Expand Down

0 comments on commit 5e42e15

Please sign in to comment.