Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[R-package] ensure use of interaction_constraints does not lead to features being ignored #6377

Merged
merged 22 commits into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
4541fed
Improve .check_interaction_constraints()
mayer79 Mar 20, 2024
7e2fcc2
fix remaining_indices
mayer79 Mar 20, 2024
490f61f
Replace is.list() by identical(class(.), "list")
mayer79 Mar 20, 2024
5bd528c
Merge branch 'master' into r-api-interaction-constraints
mayer79 Mar 21, 2024
ee26f69
Updated one unit test
mayer79 Mar 21, 2024
947a49c
Fix failing unit test
mayer79 Mar 25, 2024
6688745
Add unit test to check if skipped features in interaction constraints…
mayer79 Mar 25, 2024
61d689e
Fix enumeration in comments
mayer79 Mar 27, 2024
a26117d
Changed the unit test on too large feature index in interaction const…
mayer79 Mar 27, 2024
8bd1c49
More beautiful error message on bad features in interaction constraints
mayer79 Mar 27, 2024
0833b7e
Update error message in unit test on interaction constraints
mayer79 Mar 27, 2024
c7902e7
Replace paste(, collapse) by toString()
mayer79 Mar 27, 2024
c3485c1
stop expect_error() to recognize string as regular expression
mayer79 Mar 27, 2024
a00c5f8
Fix linting error
mayer79 Mar 27, 2024
24e8b1c
Merge branch 'master' into r-api-interaction-constraints
mayer79 Apr 11, 2024
d7c3df3
Merge branch 'master' into r-api-interaction-constraints
mayer79 Apr 20, 2024
f6cce45
Merge branch 'master' into r-api-interaction-constraints
mayer79 May 9, 2024
0265f96
Merge branch 'master' into r-api-interaction-constraints
mayer79 Jun 9, 2024
83b83b1
Check that missing features are added to interaction constraints corr…
mayer79 Jun 10, 2024
e74ff69
Merge branch 'master' into r-api-interaction-constraints
jameslamb Jun 13, 2024
6eb30a2
work around pre-commit issue
jameslamb Jun 13, 2024
b23d5a2
skip pre-commit
jameslamb Jun 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
100 changes: 49 additions & 51 deletions R-package/R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,68 +59,66 @@

}

# [description]
#
# Besides applying checks, this function
#
# 1. turns feature *names* into 1-based integer positions, then
# 2. adds an extra list element with skipped features, then
# 2. turns 1-based integer positions into 0-based positions, and finally
# 3. collapses the values of each list element into a string like "[0, 1]".
mayer79 marked this conversation as resolved.
Show resolved Hide resolved
#
.check_interaction_constraints <- function(interaction_constraints, column_names) {
if (is.null(interaction_constraints)) {
return(list())
}
if (!identical(class(interaction_constraints), "list")) {
stop("interaction_constraints must be a list")
}

# Convert interaction constraints to feature numbers
string_constraints <- list()
column_indices <- seq_along(column_names)

if (!is.null(interaction_constraints)) {
# Convert feature names to 1-based integer positions and apply checks
for (j in seq_along(interaction_constraints)) {
constraint <- interaction_constraints[[j]]

if (!methods::is(interaction_constraints, "list")) {
stop("interaction_constraints must be a list")
}
constraint_is_character_or_numeric <- sapply(
X = interaction_constraints
, FUN = function(x) {
return(is.character(x) || is.numeric(x))
}
)
if (!all(constraint_is_character_or_numeric)) {
stop("every element in interaction_constraints must be a character vector or numeric vector")
if (is.character(constraint)) {
constraint_indices <- match(constraint, column_names)
} else if (is.numeric(constraint)) {
constraint_indices <- as.integer(constraint)
} else {
stop("every element in interaction_constraints must be a character vector or numeric vector")
}

for (constraint in interaction_constraints) {

# Check for character name
if (is.character(constraint)) {

constraint_indices <- as.integer(match(constraint, column_names) - 1L)

# Provided indices, but some indices are not existing?
if (sum(is.na(constraint_indices)) > 0L) {
stop(
"supplied an unknown feature in interaction_constraints "
, sQuote(constraint[is.na(constraint_indices)])
)
}

} else {

# Check that constraint indices are at most number of features
if (max(constraint) > length(column_names)) {
stop(
"supplied a too large value in interaction_constraints: "
, max(constraint)
, " but only "
, length(column_names)
, " features"
)
}

# Store indices as [0, n-1] indexed instead of [1, n] indexed
constraint_indices <- as.integer(constraint - 1L)

}

# Convert constraint to string
constraint_string <- paste0("[", paste0(constraint_indices, collapse = ","), "]")
string_constraints <- append(string_constraints, constraint_string)
# Features outside range?
bad <- !(constraint_indices %in% column_indices)
if (any(bad)) {
stop(
"supplied an unknown feature in interaction_constraints "
, sQuote(constraint[bad])
)
}

interaction_constraints[[j]] <- constraint_indices
}

return(string_constraints)
# Add missing features as new interaction set
remaining_indices <- setdiff(
column_indices, sort(unique(unlist(interaction_constraints)))
)
if (length(remaining_indices) > 0L) {
interaction_constraints <- c(
interaction_constraints, list(remaining_indices)
)
}

# Turn indices 0-based and convert to string
for (j in seq_along(interaction_constraints)) {
interaction_constraints[[j]] <- paste0(
"[", paste0(interaction_constraints[[j]] - 1L, collapse = ","), "]"
)
}
return(interaction_constraints)
}


Expand Down
4 changes: 2 additions & 2 deletions R-package/tests/testthat/test_basic.R
Original file line number Diff line number Diff line change
Expand Up @@ -2773,7 +2773,7 @@ test_that(paste0("lgb.train() throws an informative error if the members of inte
}, "every element in interaction_constraints must be a character vector or numeric vector")
})

test_that("lgb.train() throws an informative error if interaction_constraints contains a too large index", {
test_that("lgb.train() throws an error if interaction_constraints contains wrong features", {
mayer79 marked this conversation as resolved.
Show resolved Hide resolved
dtrain <- lgb.Dataset(train$data, label = train$label)
params <- list(objective = "regression",
interaction_constraints = list(c(1L, length(colnames(train$data)) + 1L), 3L))
Expand All @@ -2783,7 +2783,7 @@ test_that("lgb.train() throws an informative error if interaction_constraints co
, params = params
, nrounds = 2L
)
}, "supplied a too large value in interaction_constraints")
})
})

test_that(paste0("lgb.train() gives same result when interaction_constraints is specified as a list of ",
Expand Down
2 changes: 1 addition & 1 deletion R-package/tests/testthat/test_lgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ test_that("Loading a Booster from a text file works", {
, bagging_freq = 1L
, boost_from_average = FALSE
, categorical_feature = c(1L, 2L)
, interaction_constraints = list(c(1L, 2L), 1L)
, interaction_constraints = list(1L:2L, 3L, 4L:ncol(train$data))
mayer79 marked this conversation as resolved.
Show resolved Hide resolved
, feature_contri = rep(0.5, ncol(train$data))
, metric = c("mape", "average_precision")
, learning_rate = 1.0
Expand Down
18 changes: 18 additions & 0 deletions R-package/tests/testthat/test_utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,21 @@ test_that(".equal_or_both_null produces expected results", {
expect_false(.equal_or_both_null(10.0, 1L))
expect_true(.equal_or_both_null(0L, 0L))
})

test_that(".check_interaction_constraints() adds skipped features", {
ref <- letters[1L:5L]
ic_num <- list(1L, c(2L, 3L))
ic_char <- list("a", c("b", "c"))
expected <- list("[0]", "[1,2]", "[3,4]")

ic_checked_num <- .check_interaction_constraints(
interaction_constraints = ic_num, column_names = ref
)

ic_checked_char <- .check_interaction_constraints(
interaction_constraints = ic_char, column_names = ref
)

expect_equal(ic_checked_num, expected)
expect_equal(ic_checked_char, expected)
})