Save each evaluation epoch (#721)
* save_checkpoint on each_evaluation_epoch

* add tokenizer_kwargs to cfgs
also fix temperature in cfgs when do_sample is false

* add missing cfg keys

* no overwrite

* Add to local download

* fix test

* fix for deepspeed

* if not path:
pascal-pfeiffer committed May 24, 2024
1 parent 4e46ed5 commit 7ddca05
Showing 19 changed files with 120 additions and 52 deletions.
2 changes: 1 addition & 1 deletion documentation/docs/get-started/llm-studio-performance.md
@@ -131,7 +131,7 @@ prediction:
     num_history: 4
     repetition_penalty: 1.2
     stop_tokens: ''
-    temperature: 0.3
+    temperature: 0.0
     top_k: 0
     top_p: 1.0
 problem_type: text_causal_language_modeling
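The temperature change above, repeated across the example and test configs below, tracks the commit note "also fix temperature in cfgs when do_sample is false": with sampling disabled, decoding is greedy and the temperature knob is never applied, so 0.0 is the honest placeholder. A minimal sketch with Hugging Face transformers; the prompt is illustrative:

```python
# Minimal sketch: with do_sample=False, generate() decodes greedily and the
# sampling knobs (temperature, top_k, top_p) are simply not used.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("h2oai/h2o-danube2-1.8b-base")
model = AutoModelForCausalLM.from_pretrained("h2oai/h2o-danube2-1.8b-base")

inputs = tokenizer("Why is the sky blue?", return_tensors="pt")
outputs = model.generate(**inputs, do_sample=False, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```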
5 changes: 5 additions & 0 deletions documentation/docs/tooltips/experiments/_save-checkpoint.mdx
@@ -7,4 +7,9 @@ When set to **Best** it saves the model weights for the epoch exhibiting the best
 - The default goal should be to attempt to tune models so that the last epoch is the best epoch.
 - Suppose an evident decline for later epochs is observed in logging. In that case, it is usually better to adjust hyperparameters, such as reducing the number of epochs or increasing regularization, instead of turning this setting on.
 
+When set to **Each evaluation epoch** it will save the model weights for each evaluation epoch.
+- This can be useful for debugging and experimenting, but will consume more disk space.
+- Models uploaded to Hugging Face Hub will only contain the last checkpoint.
+- Local downloads will contain all checkpoints.
+
 When set to **Disable** it will not save the checkpoint at all. This can be useful for debugging and experimenting in order to save disk space, but will disable certain functionalities like chatting or pushing to HF.
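For orientation, here is a minimal sketch (not the actual LLM Studio implementation) of how the modes described in this tooltip could dispatch saving after an evaluation epoch. The function and argument names are illustrative; the `{"model": state_dict}` payload and the `checkpoint.pth` filename mirror `save_checkpoint` in `llm_studio/src/utils/modeling_utils.py` below:

```python
import os

import torch


def maybe_save_checkpoint(
    mode: str,
    model: torch.nn.Module,
    base_path: str,
    epoch: int,
    improved: bool,
) -> None:
    """Illustrative dispatch for: last, best, each_evaluation_epoch, disable."""
    if mode == "disable":
        return
    if mode == "best" and not improved:
        return
    path = base_path
    if mode == "each_evaluation_epoch":
        # Keep every evaluation epoch in its own subfolder; such subfolders are
        # what the local download below bundles as intermediate checkpoints.
        # The epoch_{n} naming is an assumption for this sketch.
        path = os.path.join(base_path, f"epoch_{epoch}")
    os.makedirs(path, exist_ok=True)
    torch.save({"model": model.state_dict()}, os.path.join(path, "checkpoint.pth"))
```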
10 changes: 9 additions & 1 deletion examples/example_oasst2.yaml
@@ -36,14 +36,21 @@ dataset:
     validation_strategy: automatic
 environment:
     compile_model: false
+    deepspeed_allgather_bucket_size: 1000000
+    deepspeed_method: ZeRO2
+    deepspeed_reduce_bucket_size: 1000000
+    deepspeed_stage3_param_persistence_threshold: 1000000
+    deepspeed_stage3_prefetch_bucket_size: 1000000
     find_unused_parameters: false
     gpus:
     - '0'
     huggingface_branch: main
     mixed_precision: true
+    mixed_precision_dtype: bfloat16
     number_of_workers: 8
     seed: -1
     trust_remote_code: true
+    use_deepspeed: false
 experiment_name: example_oasst2
 llm_backbone: h2oai/h2o-danube2-1.8b-base
 logging:
@@ -63,7 +70,7 @@ prediction:
     num_history: 4
     repetition_penalty: 1.2
     stop_tokens: ''
-    temperature: 0.3
+    temperature: 0.0
     top_k: 0
     top_p: 1.0
 problem_type: text_causal_language_modeling
@@ -73,6 +80,7 @@ tokenizer:
     max_length_answer: 256
     max_length_prompt: 256
     padding_quantile: 1.0
+    tokenizer_kwargs: '{"use_fast": true, "add_prefix_space": false}'
 training:
     batch_size: 2
     differential_learning_rate: 1.0e-05
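The new `tokenizer_kwargs` entry added throughout the configs is a JSON string. A hedged sketch of how it could be parsed and forwarded to the tokenizer; the exact plumbing inside LLM Studio is not part of this diff, and the backbone name is taken from the config above:

```python
import json

from transformers import AutoTokenizer

# The config stores the kwargs as a JSON string, hence the JSON-style booleans.
tokenizer_kwargs = json.loads('{"use_fast": true, "add_prefix_space": false}')
tokenizer = AutoTokenizer.from_pretrained(
    "h2oai/h2o-danube2-1.8b-base", **tokenizer_kwargs
)
```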
14 changes: 14 additions & 0 deletions llm_studio/app_utils/sections/experiment.py
@@ -1705,6 +1705,7 @@ async def experiment_download_model(q: Q):
     # Add all files that were created after the model was saved.
     # This is useful for potential changes/different
     # naming conventions across different backbones.
+    # Also adds newly generated safetensor files.
     for file in os.listdir(checkpoint_path):
         file_path = os.path.join(checkpoint_path, file)
         if (
@@ -1718,6 +1719,19 @@
                 f"Added {file_path} to zip file as it "
                 "was created when saving the model state."
             )
+
+    # Add all files from subdirectories, which include the intermediate checkpoints
+    subdirectories = [
+        d
+        for d in os.listdir(checkpoint_path)
+        if os.path.isdir(os.path.join(checkpoint_path, d))
+    ]
+    for subdirectory in subdirectories:
+        for file in os.listdir(os.path.join(checkpoint_path, subdirectory)):
+            file_path = os.path.join(checkpoint_path, subdirectory, file)
+            add_file_to_zip(zf=zf, path=file_path, folder=subdirectory)
+            paths_added.append(file_path)
+            logger.info(f"Added {file_path} to zip file.")
     zf.close()
 
     download_url = get_download_link(q, zip_path)
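With the new loop, every checkpoint subdirectory lands as its own folder inside the downloaded zip. A quick way to inspect the resulting layout; the zip path and the epoch folder names are illustrative:

```python
import zipfile

with zipfile.ZipFile("experiment.zip") as zf:
    for name in zf.namelist():
        # e.g. checkpoint.pth, epoch_0/checkpoint.pth, epoch_1/checkpoint.pth, ...
        print(name)
```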
2 changes: 1 addition & 1 deletion llm_studio/app_utils/sections/home.py
@@ -55,7 +55,7 @@ async def home(q: Q) -> None:
                     ("queued + running", num_running_queued),
                     ("failed + stopped", num_failed_stopped),
                 ],
-                pack=True,  # type: ignore
+                pack=True,
             ),
         )
     ],
(file name not shown)
@@ -196,6 +196,7 @@ def __post_init__(self):
             values=(
                 ("last", "Last"),
                 ("best", "Best"),
+                ("each_evaluation_epoch", "Each evaluation epoch"),
                 ("disable", "Disable"),
             ),
             allow_custom=False,
63 changes: 31 additions & 32 deletions llm_studio/src/utils/modeling_utils.py
@@ -86,53 +86,52 @@ def check_disk_space(model: torch.nn.Module, path: str):
 
 
 # TODO: currently not saving optimizer
-def save_checkpoint(model: torch.nn.Module, path: str, cfg: Any):
+def save_checkpoint(model: torch.nn.Module, path: str, cfg: Any) -> None:
     """Saves a model checkpoint if the path is provided.
     Args:
         model: model to save
         path: path to save the checkpoint to
-    Returns:
-        Dictionary with all the keys to save
     """
 
+    if not path:
+        raise ValueError(f"Path must be provided. Received {path}.")
+
     if not os.path.exists(path):
         os.makedirs(path)
 
     if cfg.environment.use_deepspeed:
-        if path is not None:
-            # gather model params from all ranks when using Deepspeed
-            status = model.save_16bit_model(path, "checkpoint.pth")  # type: ignore
-            if status:
-                if cfg.environment._local_rank == 0:
-                    checkpoint = {
-                        "model": torch.load(
-                            os.path.join(path, "checkpoint.pth"), map_location="cpu"
-                        )
-                    }
-            else:
-                logger.warning(
-                    "deepspeed.save_16bit_model didn't save the model, since"
-                    " stage3_gather_16bit_weights_on_model_save=False."
-                    " Saving the full checkpoint instead"
-                )
-                model.save_checkpoint(  # type: ignore
-                    os.path.join(path, "ds_checkpoint")
-                )
-                if cfg.environment._local_rank == 0:
-                    # load to cpu
-                    state_dict = get_fp32_state_dict_from_zero_checkpoint(
-                        os.path.join(path, "ds_checkpoint")
-                    )
-                    # save as normal checkpoint that can be loaded by `load_state_dict`
-                    checkpoint = {"model": state_dict}
-                    torch.save(checkpoint, os.path.join(path, "checkpoint.pth"))
-                    shutil.rmtree(os.path.join(path, "ds_checkpoint"))
+        # gather model params from all ranks when using Deepspeed
+        status = model.save_16bit_model(path, "checkpoint.pth")
+        if status:
+            if cfg.environment._local_rank == 0:
+                checkpoint = {
+                    "model": torch.load(
+                        os.path.join(path, "checkpoint.pth"), map_location="cpu"
+                    )
+                }
+        else:
+            logger.warning(
+                "deepspeed.save_16bit_model didn't save the model, since"
+                " stage3_gather_16bit_weights_on_model_save=False."
+                " Saving the full checkpoint instead"
+            )
+            model.save_checkpoint(os.path.join(path, "ds_checkpoint"))
+            if cfg.environment._local_rank == 0:
+                # load to cpu
+                state_dict = get_fp32_state_dict_from_zero_checkpoint(
+                    os.path.join(path, "ds_checkpoint")
+                )
+                # save as normal checkpoint that can be loaded by `load_state_dict`
+                checkpoint = {"model": state_dict}
+                torch.save(checkpoint, os.path.join(path, "checkpoint.pth"))
+                shutil.rmtree(os.path.join(path, "ds_checkpoint"))
 
     else:
         if cfg.environment._local_rank == 0:
             model = unwrap_model(model)
             checkpoint = {"model": model.state_dict()}
-            if path is not None:
-                torch.save(checkpoint, os.path.join(path, "checkpoint.pth"))
+            torch.save(checkpoint, os.path.join(path, "checkpoint.pth"))
 
     if (
         cfg.environment._local_rank == 0
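Whichever branch runs, the function leaves behind a plain `checkpoint.pth` holding `{"model": state_dict}`, so restoring it needs no DeepSpeed machinery. A small sketch of the consuming side; the loader name is illustrative:

```python
import os

import torch


def load_checkpoint(model: torch.nn.Module, path: str) -> torch.nn.Module:
    # checkpoint.pth stores {"model": state_dict}, as written by save_checkpoint
    checkpoint = torch.load(os.path.join(path, "checkpoint.pth"), map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    return model
```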
28 changes: 26 additions & 2 deletions llm_studio/src/utils/utils.py
@@ -91,16 +91,21 @@ def kill_ddp_processes(kill_parent=True) -> None:
     current_process.kill()
 
 
-def add_file_to_zip(zf: zipfile.ZipFile, path: str) -> None:
+def add_file_to_zip(zf: zipfile.ZipFile, path: str, folder=None) -> None:
     """Adds a file to the existing zip. Does nothing if file does not exist.
     Args:
         zf: zipfile object to add to
         path: path to the file to add
+        folder: folder in the zip to add the file to
     """
 
     try:
-        zf.write(path, os.path.basename(path))
+        if folder is None:
+            zip_path = os.path.basename(path)
+        else:
+            zip_path = os.path.join(folder, os.path.basename(path))
+        zf.write(path, zip_path)
     except Exception:
         logger.warning(f"File {path} could not be added to zip.")
 
@@ -165,3 +170,22 @@ def __exit__(self, exc_type, exc_val, exc_tb):
             setattr(self.obj, self.attribute, self.original_value)
         else:
             delattr(self.obj, self.attribute)
+
+
+def create_symlinks_in_parent_folder(directory):
+    """For each file in a folder, create a symbolic link to that in the parent folder"""
+
+    if not os.path.exists(directory):
+        raise FileNotFoundError(f"Directory {directory} does not exist.")
+
+    parent_directory = os.path.dirname(directory)
+    files = [
+        f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))
+    ]
+
+    for file in files:
+        src = os.path.join(directory, file)
+        dst = os.path.join(parent_directory, file)
+        if os.path.exists(dst):
+            os.remove(dst)
+        os.symlink(src, dst)
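A hedged usage sketch for the two helpers touched above; the archive and folder paths are illustrative, and the import path follows the file's location in the repo:

```python
import zipfile

from llm_studio.src.utils.utils import add_file_to_zip, create_symlinks_in_parent_folder

# Place a checkpoint inside a named folder of the archive instead of at its root.
with zipfile.ZipFile("experiment.zip", "w") as zf:
    add_file_to_zip(zf=zf, path="output/epoch_1/checkpoint.pth", folder="epoch_1")

# Mirror the files of a checkpoint subfolder into its parent run directory.
create_symlinks_in_parent_folder("output/epoch_1")
```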
(file name not shown)
@@ -66,6 +66,7 @@ tokenizer:
     max_length_answer: 256
     max_length_prompt: 256
     padding_quantile: 1.0
+    tokenizer_kwargs: '{"use_fast": true, "add_prefix_space": false}'
 training:
     batch_size: 2
     differential_learning_rate: 1.0e-05
(file name not shown)
@@ -65,6 +65,7 @@ tokenizer:
     max_length_answer: 16
     max_length_prompt: 16
     padding_quantile: 1.0
+    tokenizer_kwargs: '{"use_fast": true, "add_prefix_space": false}'
 training:
     batch_size: 6
     differential_learning_rate: 1.0e-05
(file name not shown)
@@ -64,7 +64,7 @@ prediction:
     num_history: 4
     repetition_penalty: 1.2
     stop_tokens: ''
-    temperature: 0.3
+    temperature: 0.0
     top_k: 0
     top_p: 1.0
 problem_type: text_causal_language_modeling
@@ -74,6 +74,7 @@ tokenizer:
     max_length_answer: 256
     max_length_prompt: 256
     padding_quantile: 1.0
+    tokenizer_kwargs: '{"use_fast": true, "add_prefix_space": false}'
 training:
     batch_size: 2
     differential_learning_rate: 1.0e-05
(file name not shown)
@@ -63,7 +63,7 @@ prediction:
     num_history: 4
     repetition_penalty: 1.2
     stop_tokens: ''
-    temperature: 0.3
+    temperature: 0.0
     top_k: 0
     top_p: 1.0
 problem_type: text_causal_language_modeling
@@ -73,6 +73,7 @@ tokenizer:
     max_length_answer: 16
     max_length_prompt: 16
     padding_quantile: 1.0
+    tokenizer_kwargs: '{"use_fast": true, "add_prefix_space": false}'
 training:
     batch_size: 8
     differential_learning_rate: 1.0e-05
(file name not shown)
@@ -66,6 +66,7 @@ tokenizer:
     max_length_answer: 256
     max_length_prompt: 256
     padding_quantile: 1.0
+    tokenizer_kwargs: '{"use_fast": true, "add_prefix_space": false}'
 training:
     batch_size: 2
     differential_learning_rate: 1.0e-05
(file name not shown)
@@ -65,6 +65,7 @@ tokenizer:
     max_length_answer: 16
     max_length_prompt: 16
     padding_quantile: 1.0
+    tokenizer_kwargs: '{"use_fast": true, "add_prefix_space": false}'
 training:
     batch_size: 2
     differential_learning_rate: 1.0e-05
(file name not shown)
@@ -63,7 +63,7 @@ prediction:
     num_history: 4
     repetition_penalty: 1.2
     stop_tokens: ''
-    temperature: 0.3
+    temperature: 0.0
     top_k: 0
     top_p: 1.0
 problem_type: text_sequence_to_sequence_modeling
@@ -73,6 +73,7 @@ tokenizer:
     max_length_answer: 256
     max_length_prompt: 256
     padding_quantile: 1.0
+    tokenizer_kwargs: '{"use_fast": true, "add_prefix_space": false}'
 training:
     batch_size: 2
     differential_learning_rate: 1.0e-05
(file name not shown)
@@ -63,7 +63,7 @@ prediction:
     num_history: 4
     repetition_penalty: 1.2
     stop_tokens: ''
-    temperature: 0.3
+    temperature: 0.0
     top_k: 0
     top_p: 1.0
 problem_type: text_sequence_to_sequence_modeling
@@ -73,6 +73,7 @@ tokenizer:
     max_length_answer: 16
     max_length_prompt: 16
     padding_quantile: 1.0
+    tokenizer_kwargs: '{"use_fast": true, "add_prefix_space": false}'
 training:
     batch_size: 2
     differential_learning_rate: 1.0e-05
3 changes: 2 additions & 1 deletion tests/src/test_data/cfg.yaml
@@ -48,13 +48,14 @@ prediction:
     num_beams: 2
     repetition_penalty: 1.2
     stop_tokens: ""
-    temperature: 0.3
+    temperature: 0.0
 problem_type: text_causal_language_modeling
 tokenizer:
     max_length: 144
     max_length_answer: 256
     max_length_prompt: 256
     padding_quantile: 1.0
+    tokenizer_kwargs: '{"use_fast": true, "add_prefix_space": false}'
 training:
     batch_size: 3
     epochs: 0
2 changes: 1 addition & 1 deletion tests/src/utils/test_load_yaml_file.py
@@ -50,7 +50,7 @@ def test_load_config_yaml():
     assert cfg.prediction.num_beams == 2
     assert cfg.prediction.repetition_penalty == 1.2
     assert cfg.prediction.stop_tokens == ""
-    assert cfg.prediction.temperature == 0.3
+    assert cfg.prediction.temperature == 0.0
 
     assert cfg.tokenizer.max_length == 144
     assert cfg.tokenizer.max_length_answer == 256
