Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[R-package] expose start_iteration to dump/save/lgb.model.dt.tree #6398

Merged
merged 30 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
839eebf
Expose start_iteration to dump()
mayer79 Apr 1, 2024
e340f34
Expose start_iteration to save()
mayer79 Apr 1, 2024
775c32f
Expose start_iteration to Booster$save_model_to_string()
mayer79 Apr 1, 2024
ce9ceb9
Merge branch 'master' into r-batchwise-imp
mayer79 Apr 1, 2024
77d63e5
add missing argument to LGBM_BoosterSaveModel_R()
mayer79 Apr 2, 2024
ef107ef
Merge branch 'master' into r-batchwise-imp
mayer79 Apr 20, 2024
2440c54
add unit tests
mayer79 Apr 21, 2024
740f7fb
forgot to run lintr on unit tests
mayer79 Apr 21, 2024
f67f7f8
add unit test for lgb.save()
mayer79 Apr 21, 2024
3d516a6
superfluous empty space
mayer79 Apr 23, 2024
0a75f3f
better wording of docstring
mayer79 Apr 23, 2024
88bdef1
Use function to generate model in tests
mayer79 Apr 23, 2024
1180462
rename get_n_trees() to .get_n_trees()
mayer79 Apr 23, 2024
37c7988
Merge branch 'r-batchwise-imp' of https://github.com/mayer79/LightGBM…
mayer79 Apr 23, 2024
bb1346e
More strict test
mayer79 Apr 23, 2024
4d56543
Expose .get_test_model() to tests
mayer79 Apr 23, 2024
147227f
Merge branch 'r-batchwise-imp' of https://github.com/mayer79/LightGBM…
mayer79 Apr 23, 2024
a5f132d
Turn to 1 based start_iterations and move argument to the end
mayer79 Apr 23, 2024
0224c30
fix argument order in C++ API of R
mayer79 Apr 23, 2024
308d7ed
Update unit tests
mayer79 Apr 23, 2024
bdcc11d
Merge branch 'master' into r-batchwise-imp
mayer79 May 9, 2024
ed60761
More unit tests
mayer79 May 10, 2024
a149d6d
fix LGBM_BoosterSaveModelToString_R()
mayer79 May 10, 2024
f3c3aaf
Linter
mayer79 May 10, 2024
7e720d9
Unit tests for save_model_to_string() like for save_model()
mayer79 May 11, 2024
bff3e68
Merge branch 'master' into r-batchwise-imp
mayer79 May 11, 2024
ed01a84
Apply suggestions from code review
mayer79 May 12, 2024
2958cfa
Update docstring of lgb.model.dt.tree.R
mayer79 May 12, 2024
42e8204
roxygenize
mayer79 May 12, 2024
48e468c
Clarify start_iteration also for lgb.save() and lgb.dump()
mayer79 May 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
42 changes: 33 additions & 9 deletions R-package/R/lgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,12 @@ Booster <- R6::R6Class(
},

# Save model
save_model = function(filename, num_iteration = NULL, feature_importance_type = 0L) {
save_model = function(
filename
, num_iteration = NULL
, start_iteration = 0L
mayer79 marked this conversation as resolved.
Show resolved Hide resolved
, feature_importance_type = 0L
) {

self$restore_handle()

Expand All @@ -430,14 +435,20 @@ Booster <- R6::R6Class(
LGBM_BoosterSaveModel_R
, private$handle
, as.integer(num_iteration)
, as.integer(start_iteration)
, as.integer(feature_importance_type)
, filename
)

return(invisible(self))
},

save_model_to_string = function(num_iteration = NULL, feature_importance_type = 0L, as_char = TRUE) {
save_model_to_string = function(
num_iteration = NULL
, start_iteration = 0L
, feature_importance_type = 0L
, as_char = TRUE
) {

self$restore_handle()

Expand All @@ -449,6 +460,7 @@ Booster <- R6::R6Class(
LGBM_BoosterSaveModelToString_R
, private$handle
, as.integer(num_iteration)
, as.integer(start_iteration)
, as.integer(feature_importance_type)
)

Expand All @@ -461,7 +473,9 @@ Booster <- R6::R6Class(
},

# Dump model in memory
dump_model = function(num_iteration = NULL, feature_importance_type = 0L) {
dump_model = function(
num_iteration = NULL, start_iteration = 0L, feature_importance_type = 0L
) {

self$restore_handle()

Expand All @@ -473,6 +487,7 @@ Booster <- R6::R6Class(
LGBM_BoosterDumpModel_R
, private$handle
, as.integer(num_iteration)
, as.integer(start_iteration)
, as.integer(feature_importance_type)
)

Expand Down Expand Up @@ -1288,8 +1303,9 @@ lgb.load <- function(filename = NULL, model_str = NULL) {
#' @title Save LightGBM model
#' @description Save LightGBM model
#' @param booster Object of class \code{lgb.Booster}
#' @param filename saved filename
#' @param num_iteration number of iteration want to predict with, NULL or <= 0 means use best iteration
#' @param filename Saved filename
#' @param num_iteration Number of iterations to save, NULL or <= 0 means use best iteration
#' @param start_iteration First iteration to save. Default is 0, i.e., start at first
mayer79 marked this conversation as resolved.
Show resolved Hide resolved
#'
#' @return lgb.Booster
#'
Expand Down Expand Up @@ -1322,7 +1338,9 @@ lgb.load <- function(filename = NULL, model_str = NULL) {
#' lgb.save(model, tempfile(fileext = ".txt"))
#' }
#' @export
lgb.save <- function(booster, filename, num_iteration = NULL) {
lgb.save <- function(
booster, filename, num_iteration = NULL, start_iteration = 0L
) {

if (!.is_Booster(x = booster)) {
stop("lgb.save: booster should be an ", sQuote("lgb.Booster"))
Expand All @@ -1338,6 +1356,7 @@ lgb.save <- function(booster, filename, num_iteration = NULL) {
invisible(booster$save_model(
filename = filename
, num_iteration = num_iteration
, start_iteration = start_iteration
))
)

Expand All @@ -1347,7 +1366,8 @@ lgb.save <- function(booster, filename, num_iteration = NULL) {
#' @title Dump LightGBM model to json
#' @description Dump LightGBM model to json
#' @param booster Object of class \code{lgb.Booster}
#' @param num_iteration number of iteration want to predict with, NULL or <= 0 means use best iteration
#' @param num_iteration Number of iterations to be dumped. NULL or <= 0 means use best iteration
#' @param start_iteration Start index of iteration. Default is 0, i.e., start at the first iteration
#'
#' @return json format of model
#'
Expand Down Expand Up @@ -1380,14 +1400,18 @@ lgb.save <- function(booster, filename, num_iteration = NULL) {
#' json_model <- lgb.dump(model)
#' }
#' @export
lgb.dump <- function(booster, num_iteration = NULL) {
lgb.dump <- function(booster, num_iteration = NULL, start_iteration = 0L) {

if (!.is_Booster(x = booster)) {
stop("lgb.dump: booster should be an ", sQuote("lgb.Booster"))
}

# Return booster at requested iteration
return(booster$dump_model(num_iteration = num_iteration))
return(
booster$dump_model(
num_iteration = num_iteration, start_iteration = start_iteration
mayer79 marked this conversation as resolved.
Show resolved Hide resolved
)
)

}

Expand Down
19 changes: 12 additions & 7 deletions R-package/R/lgb.model.dt.tree.R
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#' @name lgb.model.dt.tree
#' @title Parse a LightGBM model json dump
#' @description Parse a LightGBM model json dump into a \code{data.table} structure.
#' @param model object of class \code{lgb.Booster}
#' @param num_iteration number of iterations you want to predict with. NULL or
#' <= 0 means use best iteration
#' @param model object of class \code{lgb.Booster}.
#' @param num_iteration Number of iterations to be parsed. NULL or <= 0 means use best iteration.
mayer79 marked this conversation as resolved.
Show resolved Hide resolved
#' @param start_iteration Start index of the iteration (default 0).
#' @return
#' A \code{data.table} with detailed information about model trees' nodes and leafs.
#'
Expand Down Expand Up @@ -51,9 +51,15 @@
#' @importFrom data.table := rbindlist
#' @importFrom jsonlite fromJSON
#' @export
lgb.model.dt.tree <- function(model, num_iteration = NULL) {

json_model <- lgb.dump(booster = model, num_iteration = num_iteration)
lgb.model.dt.tree <- function(
model, num_iteration = NULL, start_iteration = 0L
) {

json_model <- lgb.dump(
booster = model
, num_iteration = num_iteration
, start_iteration = start_iteration
)

parsed_json_model <- jsonlite::fromJSON(
txt = json_model
Expand Down Expand Up @@ -84,7 +90,6 @@ lgb.model.dt.tree <- function(model, num_iteration = NULL) {
tree_dt[, split_feature := feature_names]

return(tree_dt)

}


Expand Down
6 changes: 4 additions & 2 deletions R-package/man/lgb.dump.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 5 additions & 4 deletions R-package/man/lgb.model.dt.tree.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 5 additions & 3 deletions R-package/man/lgb.save.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 12 additions & 7 deletions R-package/src/lightgbm_R.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1092,29 +1092,32 @@ SEXP LGBM_BoosterPredictForMatSingleRowFast_R(SEXP handle_fastConfig,

SEXP LGBM_BoosterSaveModel_R(SEXP handle,
SEXP num_iteration,
SEXP start_iteration,
SEXP feature_importance_type,
SEXP filename) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
const char* filename_ptr = CHAR(PROTECT(Rf_asChar(filename)));
CHECK_CALL(LGBM_BoosterSaveModel(R_ExternalPtrAddr(handle), 0, Rf_asInteger(num_iteration), Rf_asInteger(feature_importance_type), filename_ptr));
CHECK_CALL(LGBM_BoosterSaveModel(R_ExternalPtrAddr(handle), Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), Rf_asInteger(feature_importance_type), filename_ptr));
UNPROTECT(1);
return R_NilValue;
R_API_END();
}

SEXP LGBM_BoosterSaveModelToString_R(SEXP handle,
SEXP num_iteration,
SEXP start_iteration,
SEXP feature_importance_type) {
SEXP cont_token = PROTECT(R_MakeUnwindCont());
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
int64_t out_len = 0;
int64_t buf_len = 1024 * 1024;
int num_iter = Rf_asInteger(num_iteration);
int start_iter = Rf_asInteger(start_iteration);
int importance_type = Rf_asInteger(feature_importance_type);
std::vector<char> inner_char_buf(buf_len);
CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data()));
CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), start_iter, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data()));
SEXP model_str = PROTECT(safe_R_raw(out_len, &cont_token));
// if the model string was larger than the initial buffer, call the function again, writing directly to the R object
if (out_len > buf_len) {
Expand All @@ -1129,6 +1132,7 @@ SEXP LGBM_BoosterSaveModelToString_R(SEXP handle,

SEXP LGBM_BoosterDumpModel_R(SEXP handle,
SEXP num_iteration,
SEXP start_iteration,
SEXP feature_importance_type) {
SEXP cont_token = PROTECT(R_MakeUnwindCont());
R_API_BEGIN();
Expand All @@ -1137,13 +1141,14 @@ SEXP LGBM_BoosterDumpModel_R(SEXP handle,
int64_t out_len = 0;
int64_t buf_len = 1024 * 1024;
int num_iter = Rf_asInteger(num_iteration);
int start_iter = Rf_asInteger(start_iteration);
int importance_type = Rf_asInteger(feature_importance_type);
std::vector<char> inner_char_buf(buf_len);
CHECK_CALL(LGBM_BoosterDumpModel(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data()));
CHECK_CALL(LGBM_BoosterDumpModel(R_ExternalPtrAddr(handle), start_iter, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data()));
// if the model string was larger than the initial buffer, allocate a bigger buffer and try again
if (out_len > buf_len) {
inner_char_buf.resize(out_len);
CHECK_CALL(LGBM_BoosterDumpModel(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, out_len, &out_len, inner_char_buf.data()));
CHECK_CALL(LGBM_BoosterDumpModel(R_ExternalPtrAddr(handle), start_iter, num_iter, importance_type, out_len, &out_len, inner_char_buf.data()));
}
model_str = PROTECT(safe_R_string(static_cast<R_xlen_t>(1), &cont_token));
SET_STRING_ELT(model_str, 0, safe_R_mkChar(inner_char_buf.data(), &cont_token));
Expand Down Expand Up @@ -1261,9 +1266,9 @@ static const R_CallMethodDef CallEntries[] = {
{"LGBM_BoosterPredictForMatSingleRow_R" , (DL_FUNC) &LGBM_BoosterPredictForMatSingleRow_R , 9},
{"LGBM_BoosterPredictForMatSingleRowFastInit_R", (DL_FUNC) &LGBM_BoosterPredictForMatSingleRowFastInit_R, 8},
{"LGBM_BoosterPredictForMatSingleRowFast_R" , (DL_FUNC) &LGBM_BoosterPredictForMatSingleRowFast_R , 3},
{"LGBM_BoosterSaveModel_R" , (DL_FUNC) &LGBM_BoosterSaveModel_R , 4},
{"LGBM_BoosterSaveModelToString_R" , (DL_FUNC) &LGBM_BoosterSaveModelToString_R , 3},
{"LGBM_BoosterDumpModel_R" , (DL_FUNC) &LGBM_BoosterDumpModel_R , 3},
{"LGBM_BoosterSaveModel_R" , (DL_FUNC) &LGBM_BoosterSaveModel_R , 5},
{"LGBM_BoosterSaveModelToString_R" , (DL_FUNC) &LGBM_BoosterSaveModelToString_R , 4},
{"LGBM_BoosterDumpModel_R" , (DL_FUNC) &LGBM_BoosterDumpModel_R , 4},
{"LGBM_NullBoosterHandleError_R" , (DL_FUNC) &LGBM_NullBoosterHandleError_R , 0},
{"LGBM_DumpParamAliases_R" , (DL_FUNC) &LGBM_DumpParamAliases_R , 0},
{"LGBM_GetMaxThreads_R" , (DL_FUNC) &LGBM_GetMaxThreads_R , 1},
Expand Down
6 changes: 6 additions & 0 deletions R-package/src/lightgbm_R.h
Original file line number Diff line number Diff line change
Expand Up @@ -807,13 +807,15 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterPredictForMatSingleRowFast_R(
* \brief save model into file
* \param handle Booster handle
* \param num_iteration, <= 0 means save all
* \param start_iteration Starting iteration
* \param feature_importance_type type of feature importance, 0: split, 1: gain
* \param filename file name
* \return R NULL value
*/
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterSaveModel_R(
SEXP handle,
SEXP num_iteration,
SEXP start_iteration,
SEXP feature_importance_type,
SEXP filename
);
Expand All @@ -822,25 +824,29 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterSaveModel_R(
* \brief create string containing model
* \param handle Booster handle
* \param num_iteration, <= 0 means save all
* \param start_iteration Starting iteration
* \param feature_importance_type type of feature importance, 0: split, 1: gain
* \return R character vector (length=1) with model string
*/
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterSaveModelToString_R(
SEXP handle,
SEXP num_iteration,
SEXP start_iteration,
SEXP feature_importance_type
);

/*!
* \brief dump model to JSON
* \param handle Booster handle
* \param num_iteration, <= 0 means save all
* \param start_iteration Index of starting iteration
* \param feature_importance_type type of feature importance, 0: split, 1: gain
* \return R character vector (length=1) with model JSON
*/
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterDumpModel_R(
SEXP handle,
SEXP num_iteration,
SEXP start_iteration,
SEXP feature_importance_type
);

Expand Down