Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add parser for tb.dat #115

Open
wants to merge 5 commits into
base: trunk
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/silicon/.gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
build
build_tb
model_hr.dat
40 changes: 40 additions & 0 deletions examples/silicon/read_tb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#!/usr/bin/env python
# Construct a Model from wannier90 tb.dat file.
import os
import shutil
import subprocess
import numpy as np
import tbmodels as tb
import matplotlib.pyplot as plt

if __name__ == "__main__":
WANNIER90_COMMAND = os.path.expanduser("~/git/wannier90/wannier90.x")
BUILD_DIR = "./build_tb"

if not os.path.exists(BUILD_DIR):
shutil.copytree("./input", BUILD_DIR)
subprocess.call([WANNIER90_COMMAND, "silicon"], cwd=BUILD_DIR)

model = tb.Model.from_wannier_tb_files(
tb_file=f"{BUILD_DIR}/silicon_tb.dat",
wsvec_file=f"{BUILD_DIR}/silicon_wsvec.dat",
)
print(model)

# Compute band structure along an arbitrary kpath
theta = 37 / 180 * np.pi
phi = 43 / 180 * np.pi
rlist = np.linspace(0, 2, 20)
klist = [
[
r * np.sin(theta) * np.cos(phi),
r * np.sin(theta) * np.sin(phi),
r * np.cos(theta),
]
for r in rlist
]

eigvals = model.eigenval(klist)

plt.plot(eigvals)
plt.show()
139 changes: 139 additions & 0 deletions src/tbmodels/_tb_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,61 @@ def _mat_to_hr(R, mat):
)
return lines

@staticmethod
def _read_tb(iterator, ignore_orbital_order=False):
"""Reads the content of a seedname_tb.dat file"""
next(iterator) # skip comment line

lattice = np.zeros((3, 3))
for i in range(3):
lattice[i, :] = np.fromstring(next(iterator), sep=" ")

num_wann = int(next(iterator))

nrpts = int(next(iterator))

# degeneracy of Wigner-Seitz grid points, 15 entries per line
deg_pts = []
# order in zip important because else the next data element is consumed
for _, line in zip(range(int(np.ceil(nrpts / 15))), iterator):
deg_pts.extend(int(x) for x in line.split())
assert len(deg_pts) == nrpts

# <0n|H|Rm>
hop_list = []
for ir in range(nrpts):
next(iterator) # skip empty
r_vec = [int(_) for _ in next(iterator).strip().split()]
for j in range(num_wann):
for i in range(num_wann):
line = next(iterator).strip().split()
iw, jw = (int(_) for _ in line[:2])
if not ignore_orbital_order and (iw != i + 1 or jw != j + 1):
raise ValueError(
f"Inconsistent orbital numbers in line '{line}'"
)
ham = (float(line[2]) + 1j * float(line[3])) / deg_pts[ir]
hop_list.append([ham, i, j, r_vec])

# <0n|r|Rm>
r_list = []
for ir in range(nrpts):
next(iterator) # skip empty
r_vec = [int(_) for _ in next(iterator).strip().split()]
for j in range(num_wann):
for i in range(num_wann):
line = next(iterator).strip().split()
iw, jw = (int(_) for _ in line[:2])
if not ignore_orbital_order and (iw != i + 1 or jw != j + 1):
raise ValueError(
f"Inconsistent orbital numbers in line '{line}'"
)
r = np.array([float(_) for _ in line[2:]])
r = r[::2] + 1j * r[1::2]
r_list.append([r, i, j, r_vec])

return lattice, num_wann, nrpts, deg_pts, hop_list, r_list

@classmethod
def from_wannier_folder(
cls, folder: str = ".", prefix: str = "wannier", **kwargs
Expand Down Expand Up @@ -714,6 +769,90 @@ def remap_hoppings(hop_entries):

return cls.from_hop_list(size=num_wann, hop_list=hop_entries, **kwargs)

@classmethod # noqa: MC0001
def from_wannier_tb_files( # pylint: disable=too-many-locals
cls,
*,
tb_file: str,
wsvec_file: str,
**kwargs,
) -> Model:
"""
Create a :class:`.Model` instance from Wannier90 output files.

Parameters
----------
tb_file :
Path of the ``*_tb.dat`` file. Together with the
``*_wsvec.dat`` file, this determines the hopping terms.
wsvec_file :
Path of the ``*_wsvec.dat`` file. This file determines the
remapping of hopping terms when ``use_ws_distance`` is used
in the Wannier90 calculation.
kwargs :
:class:`.Model` keyword arguments.
"""

if "uc" in kwargs:
raise ValueError(
"Ambiguous unit cell: It can be given either via 'uc' or the 'tb_file' keywords, but not both."
)
if "pos" in kwargs:
raise ValueError(
"Ambiguous orbital positions: The positions can be given either via the 'pos' or the 'tb_file' keywords, but not both."
)

with open(tb_file, encoding="utf-8") as f:
lattice, num_wann, _, _, hop_list, r_list = cls._read_tb(f)

kwargs["uc"] = lattice

def get_centers(r_list: ty.List[ty.Any]) -> ty.List[npt.NDArray[np.float_]]:
centers = [np.zeros(3) for _ in range(num_wann)]
for r, i, j, r_vec in r_list:
if r_vec != [0, 0, 0]:
continue
if i != j:
continue
r = np.array(r)
if not np.allclose(np.abs(r.imag), 0):
raise ValueError(f"Center should be real: WF {i+1}, center = {r}")
centers[i] = r.real
return centers

pos_cartesian: ty.Union[
ty.List[npt.NDArray[np.float_]], npt.NDArray[np.float_]
] = get_centers(r_list)

kwargs["pos"] = la.solve(kwargs["uc"].T, np.array(pos_cartesian).T).T

# hop_entries = (hop for hop in hop_entries if abs(hop[0]) > h_cutoff)
hop_entries = hop_list

with open(wsvec_file, encoding="utf-8") as f:
wsvec_generator = cls._async_parse(cls._read_wsvec(f), chunksize=num_wann)

def remap_hoppings(hop_entries):
for t, orbital_1, orbital_2, R in hop_entries:
# Step _async_parse to where it accepts
# a new key.
# The _async_parse does not raise StopIteration
next(wsvec_generator) # pylint: disable=stop-iteration-return
T_list = wsvec_generator.send((orbital_1, orbital_2, tuple(R)))
N = len(T_list)
for T in T_list:
# not using numpy here increases performance
yield (
t / N,
orbital_1,
orbital_2,
tuple(r + t for r, t in zip(R, T)),
)

hop_entries = remap_hoppings(hop_entries)

return cls.from_hop_list(size=num_wann, hop_list=hop_entries, **kwargs)

@staticmethod
def _async_parse(iterator, chunksize=1):
"""
Expand Down
Binary file not shown.
Loading