Skip to content

Commit

Permalink
feat: add image callback in sampling (#140)
Browse files Browse the repository at this point in the history
  • Loading branch information
hanxiao committed Aug 5, 2022
1 parent 4738bc0 commit d0e6e18
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 39 deletions.
18 changes: 18 additions & 0 deletions discoart/nn/helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import numpy as np
import torch
import random


def set_seed(seed: int) -> None:
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True


def detach_gpu(val):
if isinstance(val, (int, float)):
return val
else:
return val.detach().cpu().item()
6 changes: 6 additions & 0 deletions discoart/nn/transform.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
import torch
import torchvision.transforms as T

inv_normalize = T.Normalize(
mean=[-0.48145466 / 0.26862954, -0.4578275 / 0.26130258, -0.40821073 / 0.27577711],
std=[1 / 0.26862954, 1 / 0.26130258, 1 / 0.27577711],
)


def symmetry_transformation_fn(x, use_horizontal_symmetry, use_vertical_symmetry):
Expand Down
13 changes: 7 additions & 6 deletions discoart/persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def _sample(
is_save_gif,
is_image_output,
is_display_step,
image_callback,
):
with threading.Lock():
is_sampling_done.clear()
Expand All @@ -54,13 +55,13 @@ def _sample(
if is_save_step:
if is_image_output:
if cur_t == -1:
c.save_uri_to_file(
os.path.join(output_dir, f'{_nb}-done-{k}.png')
)
f_name = os.path.join(output_dir, f'{_nb}-done-{k}.png')
else:
c.save_uri_to_file(
os.path.join(output_dir, f'{_nb}-step-{j}-{k}.png')
)
f_name = os.path.join(output_dir, f'{_nb}-step-{j}-{k}.png')
c.save_uri_to_file(f_name)

if callable(image_callback):
image_callback(f_name)

da[k].chunks.append(c)

Expand Down
48 changes: 15 additions & 33 deletions discoart/runner.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import copy
import os.path
import random
import tempfile
import threading
from typing import Callable, Optional

import clip
import lpips
import numpy as np
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
import wandb
from docarray import DocumentArray, Document
Expand All @@ -26,20 +24,18 @@
get_output_dir,
is_jupyter,
)
from .nn.helper import set_seed, detach_gpu
from .nn.losses import spherical_dist_loss, tv_loss, range_loss
from .nn.make_cutouts import MakeCutouts
from .nn.sec_diff import alpha_sigma_to_t
from .nn.transform import symmetry_transformation_fn
from .nn.transform import symmetry_transformation_fn, inv_normalize
from .persist import _sample_thread, _persist_thread, _save_progress_thread
from .prompt import PromptPlanner

inv_normalize = T.Normalize(
mean=[-0.48145466 / 0.26862954, -0.4578275 / 0.26130258, -0.40821073 / 0.27577711],
std=[1 / 0.26862954, 1 / 0.26130258, 1 / 0.27577711],
)


def do_run(args, models, device, events) -> 'DocumentArray':
def do_run(
args, models, device, events, image_callback: Optional[Callable[[str], None]] = None
) -> 'DocumentArray':
skip_event, stop_event = events

_is_jupyter = is_jupyter()
Expand Down Expand Up @@ -114,7 +110,7 @@ def do_run(args, models, device, events) -> 'DocumentArray':

init = None

_set_seed(args.seed)
set_seed(args.seed)
if args.init_image:
d = Document(uri=args.init_image).load_uri_to_image_tensor(side_x, side_y)
init = (
Expand Down Expand Up @@ -293,12 +289,12 @@ def cond_fn(x, t, **kwargs):
) # min=-0.02, min=-clamp_max,

traced_info = {
'losses/total': _detach(loss) + cut_losses,
'losses/tv': _detach(tv_losses),
'losses/range': _detach(range_losses),
'losses/sat': _detach(sat_losses),
'losses/init': _detach(init_losses),
'losses/cuts': _detach(cut_losses),
'losses/total': detach_gpu(loss) + cut_losses,
'losses/tv': detach_gpu(tv_losses),
'losses/range': detach_gpu(range_losses),
'losses/sat': detach_gpu(sat_losses),
'losses/init': detach_gpu(init_losses),
'losses/cuts': detach_gpu(cut_losses),
}

traced_info.update(
Expand Down Expand Up @@ -351,7 +347,7 @@ def cond_fn(x, t, **kwargs):

# set seed for each image in the batch
new_seed = org_seed + _nb
_set_seed(new_seed)
set_seed(new_seed)
args.seed = new_seed
if _is_jupyter:
redraw_widget(
Expand Down Expand Up @@ -441,6 +437,7 @@ def cond_fn(x, t, **kwargs):
args.gif_fps > 0,
args.image_output,
is_display_step,
image_callback,
)
)

Expand Down Expand Up @@ -497,18 +494,3 @@ def redraw_widget(_handlers, _redraw_fn, args, _nb):

_handlers.code.value = export_python(args)
_redraw_fn()


def _set_seed(seed: int) -> None:
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True


def _detach(val):
if isinstance(val, (int, float)):
return val
else:
return val.detach().cpu().item()

0 comments on commit d0e6e18

Please sign in to comment.