Skip to content

Commit

Permalink
visualization.py: enable lables in resulting svg file
Browse files Browse the repository at this point in the history
  • Loading branch information
cknoll committed Dec 12, 2023
1 parent 49f35bf commit ab3f1f9
Showing 1 changed file with 32 additions and 6 deletions.
38 changes: 32 additions & 6 deletions src/pyerk/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Union, List, Tuple, Optional
import os
import urllib
from rdflib import Literal

import networkx as nx
import nxv # for graphviz visualization of networkx graphs
Expand Down Expand Up @@ -406,7 +407,8 @@ def create_complete_graph(
pass
else:
node = create_node(item, url_template)
G.add_node(node, label=item.short_key, color=get_color_for_item(item))
label_str = f"{item.short_key}\n{item.R1__has_label}"
G.add_node(node, label=label_str, color=get_color_for_item(item))
added_items_nodes[item_uri] = node

# iterate over relation edges
Expand Down Expand Up @@ -477,6 +479,24 @@ def render_graph_to_dot(G: nx.DiGraph) -> str:
return dot_data


def svg_replace(raw_svg_data: str, REPLACEMENTS: dict) -> str:
assert isinstance(raw_svg_data, str)

# prevent some latex stuff to interfere with the handing of the `REPLACEMENTS`
# TODO: handle the whole problme more elegantly

latex_replacements = [(r"\dot{x}", "__LATEX1__")]
for orig, subs in latex_replacements:
raw_svg_data = raw_svg_data.replace(orig, subs)

svg_data1: str = raw_svg_data.format(**REPLACEMENTS)

for orig, subs in latex_replacements:
svg_data1 = svg_data1.replace(subs, orig)

return svg_data1


def visualize_entity(uri: str, url_template="", write_tmp_files: bool = False) -> str:
"""
Expand All @@ -502,8 +522,8 @@ def visualize_entity(uri: str, url_template="", write_tmp_files: bool = False) -

# noinspection PyUnresolvedReferences,PyProtectedMember
raw_svg_data = nxv._graphviz.run(dot_data, algorithm="dot", format="svg", graphviz_bin=None)

svg_data1: str = raw_svg_data.decode("utf8").format(**REPLACEMENTS)
raw_svg_data = raw_svg_data.decode("utf8")
svg_data1 = svg_replace(raw_svg_data, REPLACEMENTS)

if write_tmp_files:
# for debugging
Expand All @@ -521,6 +541,12 @@ def visualize_entity(uri: str, url_template="", write_tmp_files: bool = False) -
return svg_data1


def get_label(entity):
res = entity.get("label", "undefined label")
if isinstance(res, Literal):
return res.value
return res

def visualize_all_entities(url_template="", write_tmp_files: bool = False) -> str:
G = create_complete_graph(url_template)

Expand All @@ -542,7 +568,7 @@ def visualize_all_entities(url_template="", write_tmp_files: bool = False) -> st
"width": 1.3,
"fontsize": 10,
"color": d.get("color", "black"),
"label": d.get("label", "undefined label"),
"label": get_label(d),
"shape": d.get("shape", "circle"), # see also AbstractNode.shape
},
# u: node1, v: node1, d: its attribute dict
Expand All @@ -562,7 +588,7 @@ def visualize_all_entities(url_template="", write_tmp_files: bool = False) -> st
"color": d.get("color", "black"),
"width": 0.3,
"fontsize": 2,
"label": d.get("label", "undefined label"),
"label": get_label(d),
"fillcolor": "#45454533",
},
edge=lambda u, v, d: {
Expand All @@ -581,7 +607,7 @@ def visualize_all_entities(url_template="", write_tmp_files: bool = False) -> st
dot_data = raw_dot_data
# noinspection PyUnresolvedReferences,PyProtectedMember
raw_svg_data = nxv._graphviz.run(dot_data, algorithm="dot", format="svg", graphviz_bin=None)
svg_data1: str = raw_svg_data.decode("utf8").format(**REPLACEMENTS)
svg_data1 = svg_replace(raw_svg_data.decode("utf8"), REPLACEMENTS)

if write_tmp_files:
# for debugging
Expand Down

0 comments on commit ab3f1f9

Please sign in to comment.