Skip to content

Commit

Permalink
GH-16149: fix gam rebalance bug (#16153)
Browse files Browse the repository at this point in the history
* add R test that reproduce the error.

* add more to test.

* GH-16149: fixed key passed to rebalance dataset to avoid collision.

* Adopt adam code review comments.

Co-authored-by: Veronika Maurerová <[email protected]>
  • Loading branch information
wendycwong and maurever committed Apr 22, 2024
1 parent 0c60e3e commit adeebb5
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 2 deletions.
4 changes: 2 additions & 2 deletions h2o-algos/src/main/java/hex/gam/GAM.java
Original file line number Diff line number Diff line change
Expand Up @@ -927,7 +927,7 @@ public void computeImpl() {
if (error_count() > 0) // if something goes wrong, let's throw a fit
throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(GAM.this);
// add gamified columns to training frame
Frame newTFrame = new Frame(rebalance(adaptTrain(), false, _result+".temporary.train"));
Frame newTFrame = new Frame(rebalance(adaptTrain(), false, Key.make()+".temporary.train"));
verifyGamTransformedFrame(newTFrame);

if (error_count() > 0) // if something goes wrong during gam transformation, let's throw a fit again!
Expand All @@ -937,7 +937,7 @@ public void computeImpl() {
int[] singleGamColsCount = new int[]{_cubicSplineNum, _iSplineNum, _mSplineNum};
_valid = rebalance(adaptValidFrame(_parms.valid(), _valid, _parms, _gamColNamesCenter, _binvD,
_zTranspose, _knots, _zTransposeCS, _allPolyBasisList, _gamColMeansRaw, _oneOGamColStd, singleGamColsCount),
false, _result + ".temporary.valid");
false, Key.make() + ".temporary.valid");
}
DKV.put(newTFrame); // This one will cause deleted vectors if add to Scope.track
Frame newValidFrame = _valid == null ? null : new Frame(_valid);
Expand Down
53 changes: 53 additions & 0 deletions h2o-r/tests/testdir_algos/gam/runit_GH_16149_gam_row_error.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
setwd(normalizePath(dirname(R.utils::commandArgs(asValues=TRUE)$"f")))
source("../../../scripts/h2o-r-test-setup.R")

library(data.table)

# This test was provided by a customer. No exit condition is needed as
# the test before my fix always failed. As long as this test completes
# successfully, it should be good enough.
test.gam.dataset.error <- function(n) {
sum_insured <- seq(1, 200000, length.out = n)
d2 <-
data.table(
sum_insured = sum_insured,
sqrt = sqrt(sum_insured),
sine = sin(2 * pi * sum_insured / 40000)
)
d2[, sine := 0.3 * sqrt * sine , ]
d2[, y := pmax(0, sqrt + sine) , ]

d2[, x := sum_insured]
d2[, x2 := rev(x) , ] # flip axis

# import the dataset
h2o_data2 <- as.h2o(d2)

model2 <-
h2o.gam(
y = "y",
gam_columns = c("x2"),
bs = c(2),
spline_orders = c(3),
splines_non_negative = c(F),
training_frame = h2o_data2,
family = "tweedie",
tweedie_variance_power = 1.1,
scale = c(0),
lambda = 0,
alpha = 0,
keep_gam_cols = T,
non_negative = TRUE,
num_knots = c(10)
)
print("model building completed.")
}

test.model.gam.dataset.error <- function() {
# test for n=1005
test.gam.dataset.error(1005)
# test for n=1001;
test.gam.dataset.error(1001)
}

doTest("General Additive Model dataset size 1001 and 1005 error", test.model.gam.dataset.error)

0 comments on commit adeebb5

Please sign in to comment.