Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Updrade onmt #6

Draft
wants to merge 22 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
c0b47cc
fix: upgraded onmt_preprocess to onmt_build_vocab, added wrapper argu…
irinaespejo Apr 9, 2024
ea9fb84
chore: mypy, typing, isort...
irinaespejo Apr 9, 2024
528c5ab
chore: isort
irinaespejo Apr 9, 2024
56ac6ed
chore: add yaml to ignore mypy
irinaespejo Apr 9, 2024
b3122d6
tests: add matrix with python 3.11 to github wflows
irinaespejo Apr 9, 2024
70b0450
chore: more mypy
irinaespejo Apr 9, 2024
918f494
fix: add rxn-onmt-utils depedency without rxn-opennmt-py depedency
irinaespejo Apr 15, 2024
1d43e49
fix: added explicit dependecy on official OpenNMT-py
irinaespejo Apr 15, 2024
12554b4
remove github workflow checks in python 3.7
irinaespejo Apr 15, 2024
fc5dc17
fix: pass config file to onmt_train
irinaespejo Apr 16, 2024
61fc311
chore: removed debugging console
irinaespejo Apr 16, 2024
68626f9
chore: upgrade dependecy commit to rxn-onmt-utils + black, isort...
irinaespejo Apr 16, 2024
4a3f439
fix: adapt finetune + continue_train to same changes as in train
irinaespejo Apr 16, 2024
d362b3e
fix: share vocab = False to save tgt vocab file too
irinaespejo Apr 30, 2024
2369eaf
fix: warning hidden vs rnn size, priority hidden_size
irinaespejo Apr 30, 2024
f2e026d
chore: update commit rxn-onmt-utils depedency
irinaespejo Apr 30, 2024
d9ec801
fix: correct the commit to latest
irinaespejo Apr 30, 2024
519a7c8
fix: back to share_vocab=True
irinaespejo Apr 30, 2024
98b79a0
fix: use all corpus for vocab
irinaespejo May 2, 2024
12066d2
fix: added model_task argument for correct config in retro task + upd…
irinaespejo May 24, 2024
2d95793
chore: update dep rxn-onmt-utils
irinaespejo May 27, 2024
560a899
temptative update utils dependency
irinaespejo May 29, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 26 additions & 24 deletions .github/workflows/docs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,30 @@ jobs:
build:
runs-on: ubuntu-latest
name: Build the Sphinx docs
strategy:
matrix:
python-version: ["3.11"]
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.8
uses: actions/setup-python@v3
with:
python-version: 3.8
- name: Install package dependencies
run: pip install -e .[rdkit]
- name: Install sphinx dependencies
run: pip install -r docs/requirements.txt
- name: Make docs
working-directory: ./docs
run: make html
- name: Upload artifacts
uses: actions/upload-artifact@v3
with:
name: html-docs
path: docs/build/html/
- name: Deploy
uses: peaceiris/actions-gh-pages@v3
if: github.ref == 'refs/heads/main'
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
publish_dir: docs/build/html

- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
- name: Install package dependencies
run: pip install -e .[rdkit]
- name: Install sphinx dependencies
run: pip install -r docs/requirements.txt
- name: Make docs
working-directory: ./docs
run: make html
- name: Upload artifacts
uses: actions/upload-artifact@v3
with:
name: html-docs
path: docs/build/html/
- name: Deploy
uses: peaceiris/actions-gh-pages@v3
if: github.ref == 'refs/heads/main'
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
publish_dir: docs/build/html
36 changes: 19 additions & 17 deletions .github/workflows/pypi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,28 @@ name: Build and publish rxn-onmt-models on PyPI
on:
push:
tags:
- 'v*'
- "v*"

jobs:
build-and-publish:
name: Build and publish rxn-onmt-models on PyPI
runs-on: ubuntu-latest

strategy:
matrix:
python-version: ["3.11"]
steps:
- uses: actions/checkout@master
- name: Python setup 3.9
uses: actions/setup-python@v1
with:
python-version: 3.9
- name: Install build package (for packaging)
run: pip install --upgrade build
- name: Build dist
run: python -m build
- name: Publish to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
user: __token__
password: ${{ secrets.PYPI_TOKEN }}
skip_existing: true
- uses: actions/checkout@master
- name: Python setup ${{ matrix.python-version }}
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python-version }}
- name: Install build package (for packaging)
run: pip install --upgrade build
- name: Build dist
run: python -m build
- name: Publish to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
user: __token__
password: ${{ secrets.PYPI_TOKEN }}
skip_existing: true
42 changes: 23 additions & 19 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,27 @@ jobs:
tests:
runs-on: ubuntu-latest
name: Style, mypy, pytest
strategy:
matrix:
python-version: ["3.11"]
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.7
uses: actions/setup-python@v3
with:
python-version: 3.7
- name: Install Dependencies
run: pip install -e .[dev,rdkit]
- name: Check black
run: python -m black --check --diff --color .
- name: Check isort
run: python -m isort --check --diff .
- name: Check flake8
run: python -m flake8 .
- name: Check mypy (on the package)
run: python -m mypy --namespace-packages -p rxn.onmt_models
- name: Check mypy (on the tests)
run: python -m mypy tests
- name: Run pytests
run: python -m pytest
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
- name: Install Dependencies
run: pip install -e .[dev,rdkit]
- name: Check black
run: python -m black --check --diff --color .
- name: Check isort
run: python -m isort --check --diff .
- name: Check flake8
run: python -m flake8 .
- name: Check mypy (on the package)
run: python -m mypy --namespace-packages -p rxn.onmt_models
- name: Check mypy (on the tests)
run: python -m mypy tests
- name: Run pytests
run: python -m pytest

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ module = [
"numpy.*",
"pandas.*",
"pytest.*",
"yaml.*",
]
ignore_missing_imports = true

Expand Down
3 changes: 2 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ install_requires =
attrs>=21.2.0
click>=8.0
rxn-chem-utils>=1.1.4
rxn-onmt-utils>=1.0.3
rxn-reaction-preprocessing>=2.0.2
rxn-utils>=1.1.9
rxn-onmt-utils @ git+https://github.com/rxn4chemistry/rxn-onmt-utils.git@bb9a2168aeabbfc265c4da4c3ef118a510cf130e #rxn-onmt-utils without rxn-opennmt-py depedency
OpenNMT-py>=3.5.1 # official onmt

[options.packages.find]
where = src
Expand Down
6 changes: 4 additions & 2 deletions src/rxn/onmt_models/scripts/rxn_onmt_continue_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
default=100000,
help="Number of steps, including steps from the initial training run.",
)
@click.option("--model_task", type=str, required=True)
def main(
batch_size: int,
data_weights: Tuple[int, ...],
Expand All @@ -66,6 +67,7 @@ def main(
preprocess_dir: str,
train_from: Optional[str],
train_num_steps: int,
model_task: str,
) -> None:
"""Continue training for an OpenNMT model.

Expand Down Expand Up @@ -111,11 +113,11 @@ def main(
train_steps=train_num_steps,
no_gpu=no_gpu,
data_weights=data_weights,
model_task=model_task,
)

# Write config file
command_and_args = train_cmd.save_to_config_cmd(config_file)
run_command(command_and_args)
train_cmd.save_to_config_cmd(config_file)

# Actual training config file
command_and_args = train_cmd.execute_from_config_cmd(config_file)
Expand Down
8 changes: 5 additions & 3 deletions src/rxn/onmt_models/scripts/rxn_onmt_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
@click.option("--warmup_steps", default=defaults.WARMUP_STEPS)
@click.option("--report_every", default=1000)
@click.option("--save_checkpoint_steps", default=5000)
@click.option("--model_task", type=str, required=True)
def main(
batch_size: int,
data_weights: Tuple[int, ...],
Expand All @@ -69,6 +70,7 @@ def main(
warmup_steps: int,
report_every: int,
save_checkpoint_steps: int,
model_task: str,
) -> None:
"""Finetune an OpenNMT model."""

Expand Down Expand Up @@ -112,7 +114,7 @@ def main(
dropout=dropout,
keep_checkpoint=keep_checkpoint,
learning_rate=learning_rate,
rnn_size=rnn_size,
hidden_size=rnn_size,
save_model=model_files.model_prefix,
seed=seed,
train_from=train_from,
Expand All @@ -122,11 +124,11 @@ def main(
data_weights=data_weights,
report_every=report_every,
save_checkpoint_steps=save_checkpoint_steps,
model_task=model_task,
)

# Write config file
command_and_args = train_cmd.save_to_config_cmd(config_file)
run_command(command_and_args)
train_cmd.save_to_config_cmd(config_file)

# Actual training config file
command_and_args = train_cmd.execute_from_config_cmd(config_file)
Expand Down
117 changes: 104 additions & 13 deletions src/rxn/onmt_models/scripts/rxn_onmt_preprocess.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import logging
import random
from pathlib import Path
from typing import List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

import click
import yaml
from rxn.chemutils.tokenization import ensure_tokenized_file
from rxn.onmt_utils import __version__ as onmt_utils_version
from rxn.onmt_utils.train_command import preprocessed_id_names
Expand Down Expand Up @@ -51,6 +52,89 @@ def determine_train_dataset(
return src, tgt


def get_build_vocab_config_file(
train_srcs: List[PathLike],
train_tgts: List[PathLike],
valid_src: PathLike,
valid_tgt: PathLike,
save_data: Path,
share_vocab: bool = True,
overwrite: bool = True,
src_seq_length: int = 3000,
tgt_seq_length: int = 3000,
src_vocab_size: int = 3000,
tgt_vocab_size: int = 3000,
) -> Path:
"""Wrapper function of the legacy cli `onmt_preprocessed` arguments.
The goal is to make them compatible with ONMT v.3.5.1 cli `onmt_build_vocab`.
The function takes the arguments of former onmt_preprocessed cli and dumps
them into a `config.yaml` file with a specific structure compatible with `onmt_build_vocab`.
The upgraded `onmt_build_vocab` takes them as `onmt_build_vocab -config config.yaml`.

Args:
train_srcs (List[PathLike]): List of train source data files
train_tgts (List[PathLike]): List of train target data files
valid_src (List[PathLike]): List of validation source data files
valid_tgt (List[PathLike]): List of validation target data files
save_data (PathLike): Save vocabulary data directory
share_vocab (bool, optional): Share vocab. Defaults to True.
overwrite (bool, optional): Overwrite output directory. Defaults to True.
src_seq_length (int, optional): src_seq_length. Defaults to 3000.
tgt_seq_length (int, optional): tgt_seq_length. Defaults to 3000.
src_vocab_size (int, optional): src_vocab_size. Defaults to 3000.
tgt_vocab_size (int, optional): tgt_vocab_size. Defaults to 3000.

Returns:
PathLike: Path of the config.yaml which is in directory `save_data`
"""

# Build dictionary with build vocab config content
# See structure https://opennmt.net/OpenNMT-py/quickstart.html (Step 1: Prepare the data)
build_vocab_config: Dict[str, Any] = {}

# Arguments save data
build_vocab_config["save_data"] = str(save_data.parent)
build_vocab_config["src_vocab"] = str(
save_data.parent / (save_data.name + ".vocab.src")
)
build_vocab_config["tgt_vocab"] = str(
save_data.parent / (save_data.name + ".vocab.tgt")
)

# Other arguments
build_vocab_config["overwrite"] = str(overwrite)
build_vocab_config["share_vocab"] = str(share_vocab)
build_vocab_config["src_seq_length"] = str(src_seq_length)
build_vocab_config["tgt_seq_length"] = str(tgt_seq_length)
build_vocab_config["src_vocab_size"] = str(src_vocab_size)
build_vocab_config["tgt_vocab_size"] = str(tgt_vocab_size)

# Arguments data paths (train)
build_vocab_config["data"] = {}
# TODO: raise error if lengths: train_srcs, train_tgts, valid_src, valid_tgt are different
number_corpus = len(train_srcs)
for i in range(number_corpus):
build_vocab_config["data"][f"corpus_{i+1}"] = {
"path_src": str(train_srcs[i]),
"path_tgt": str(train_tgts[i]),
}

# Arguments data paths (valid)
build_vocab_config["data"]["valid"] = {
"path_src": str(valid_src),
"path_tgt": str(valid_tgt),
}

# Path to same yaml file
config_file_path = save_data.parent / (save_data.name + "_build_vocab_config.yaml")

# Save file that will be -config argument of onmt_build_vocab
with open(config_file_path, "w+") as file:
yaml.dump(build_vocab_config, file)

return config_file_path


@click.command()
@click.option(
"--input_dir",
Expand Down Expand Up @@ -180,21 +264,28 @@ def main(
valid_src = ensure_tokenized_file(valid_src)
valid_tgt = ensure_tokenized_file(valid_tgt)

# Create config file for onmt_build_vocab for OpenNMT v.3.5.1
# Dump train_srcs, train_tgts, valid_src, valid_tgt etc and return path
config_file_path = get_build_vocab_config_file(
train_srcs=train_srcs,
train_tgts=train_tgts,
valid_src=valid_src,
valid_tgt=valid_tgt,
save_data=onmt_preprocessed_files.preprocess_prefix,
share_vocab=True,
overwrite=True,
src_seq_length=3000,
tgt_seq_length=3000,
src_vocab_size=3000,
tgt_vocab_size=3000,
)

# yapf: disable
command_and_args = [
str(e) for e in [
'onmt_preprocess',
'-train_src', *train_srcs,
'-train_tgt', *train_tgts,
'-valid_src', valid_src,
'-valid_tgt', valid_tgt,
'-save_data', onmt_preprocessed_files.preprocess_prefix,
'-src_seq_length', 3000,
'-tgt_seq_length', 3000,
'-src_vocab_size', 3000,
'-tgt_vocab_size', 3000,
'-share_vocab',
'-overwrite',
'onmt_build_vocab',
'-config', config_file_path,
'-n_sample', -1,
]
]
# yapf: enable
Expand Down
Loading
Loading