Skip to content

Commit

Permalink
Better behavior if something goes wrong
Browse files Browse the repository at this point in the history
  • Loading branch information
Samuel Pfrommer committed Jan 21, 2024
1 parent aa1231e commit db69118
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
3 changes: 3 additions & 0 deletions torchexplorer/hook/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ def record_sizes(tensors: tuple[OTensor, ...], tensor_trackers: list[SizeTracker
assert stored_shape is not None
if len(shape) != len(stored_shape):
tensor_trackers[i].size = None
continue

# TODO: with graphviz caching, this actually doesn't need to run
# since only the first pass sizes are stored.
Expand Down Expand Up @@ -336,6 +337,8 @@ def process_tensor(tensor: Tensor) -> Optional[GradFn]:
return tuple(process_tensor(tensor) for tensor in utils.iter_not_none(tensors))

def _get_next_gradfns(tensors: tuple[OTensor, ...]) -> tuple[Optional[GradFn], ...]:
if len(tensors) == 0:
return tuple()
# Hacky workaround, couldn't figure out import
# Check if all gradfns are the same 'BackwardHookFunctionBackward'
backhook_class = 'BackwardHookFunctionBackward'
Expand Down
7 changes: 5 additions & 2 deletions torchexplorer/render/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,11 @@ def _translate_inner_layouts(layout: NodeLayout) -> None:
if is_input_node(l.display_name):
input_centers.append(_center(l))

center = np.mean(np.array(input_centers), axis=0)
trans = utils.list_add(target_input_pos, [-center[0], -center[1]])
if len(input_centers) > 0:
center = np.mean(np.array(input_centers), axis=0)
trans = utils.list_add(target_input_pos, [-center[0], -center[1]])
else:
trans = target_input_pos

def apply_translation(l: Union[NodeLayout, TooltipLayout]):
l.bottom_left_corner = utils.list_add(l.bottom_left_corner, trans)
Expand Down

0 comments on commit db69118

Please sign in to comment.