Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support configuring gradio startup parameters #112

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,14 @@ You can run DragGAN Gradio demo as well:
```sh
python visualizer_drag_gradio.py
```
If you want to run DragGAN Gradio demo with given server name, for example host 0.0.0.0 to accessible on local network:
```sh
python visualizer_drag_gradio.py --host=0.0.0.0
```
If you want to run DragGAN Gradio demo with given server port, for example port 8888:
```sh
python visualizer_drag_gradio.py --port=8888
```

## Acknowledgement

Expand Down
12 changes: 11 additions & 1 deletion visualizer_drag_gradio.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,15 @@
get_latest_points_pair, get_valid_mask,
on_change_single_global_state)
from viz.renderer import Renderer, add_watermark_np
import inspect

parser = ArgumentParser()
parser.add_argument('--share', action='store_true',default='True')
parser.add_argument('--cache-dir', type=str, default='./checkpoints')
parser.add_argument('--host', type=str,
help="launch gradio with given server name", default=None)
parser.add_argument('--port', type=int,
help="launch gradio with given server port", default=None)
args = parser.parse_args()

cache_dir = args.cache_dir
Expand Down Expand Up @@ -861,6 +866,11 @@ def on_click_show_mask(global_state, show_mask):
outputs=[global_state, form_image],
)

sig = inspect.signature(app.launch)
params = sig.parameters
def_server_name = params["server_name"].default
def_server_port = params["server_port"].default

gr.close_all()
app.queue(concurrency_count=3, max_size=20)
app.launch(share=args.share)
app.launch(share=args.share, server_name=args.host if args.host is not None else def_server_name, server_port=args.port if args.port is not None else def_server_port)