From adeebb51757b0cc5da0a525307dfea55bfedcae1 Mon Sep 17 00:00:00 2001 From: wendycwong Date: Mon, 22 Apr 2024 11:25:28 -0700 Subject: [PATCH] GH-16149: fix gam rebalance bug (#16153) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add R test that reproduce the error. * add more to test. * GH-16149: fixed key passed to rebalance dataset to avoid collision. * Adopt adam code review comments. Co-authored-by: Veronika Maurerová --- h2o-algos/src/main/java/hex/gam/GAM.java | 4 +- .../gam/runit_GH_16149_gam_row_error.R | 53 +++++++++++++++++++ 2 files changed, 55 insertions(+), 2 deletions(-) create mode 100644 h2o-r/tests/testdir_algos/gam/runit_GH_16149_gam_row_error.R diff --git a/h2o-algos/src/main/java/hex/gam/GAM.java b/h2o-algos/src/main/java/hex/gam/GAM.java index 25f55742c28d..9f598be447e4 100644 --- a/h2o-algos/src/main/java/hex/gam/GAM.java +++ b/h2o-algos/src/main/java/hex/gam/GAM.java @@ -927,7 +927,7 @@ public void computeImpl() { if (error_count() > 0) // if something goes wrong, let's throw a fit throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(GAM.this); // add gamified columns to training frame - Frame newTFrame = new Frame(rebalance(adaptTrain(), false, _result+".temporary.train")); + Frame newTFrame = new Frame(rebalance(adaptTrain(), false, Key.make()+".temporary.train")); verifyGamTransformedFrame(newTFrame); if (error_count() > 0) // if something goes wrong during gam transformation, let's throw a fit again! @@ -937,7 +937,7 @@ public void computeImpl() { int[] singleGamColsCount = new int[]{_cubicSplineNum, _iSplineNum, _mSplineNum}; _valid = rebalance(adaptValidFrame(_parms.valid(), _valid, _parms, _gamColNamesCenter, _binvD, _zTranspose, _knots, _zTransposeCS, _allPolyBasisList, _gamColMeansRaw, _oneOGamColStd, singleGamColsCount), - false, _result + ".temporary.valid"); + false, Key.make() + ".temporary.valid"); } DKV.put(newTFrame); // This one will cause deleted vectors if add to Scope.track Frame newValidFrame = _valid == null ? null : new Frame(_valid); diff --git a/h2o-r/tests/testdir_algos/gam/runit_GH_16149_gam_row_error.R b/h2o-r/tests/testdir_algos/gam/runit_GH_16149_gam_row_error.R new file mode 100644 index 000000000000..2186599a4205 --- /dev/null +++ b/h2o-r/tests/testdir_algos/gam/runit_GH_16149_gam_row_error.R @@ -0,0 +1,53 @@ +setwd(normalizePath(dirname(R.utils::commandArgs(asValues=TRUE)$"f"))) +source("../../../scripts/h2o-r-test-setup.R") + +library(data.table) + +# This test was provided by a customer. No exit condition is needed as +# the test before my fix always failed. As long as this test completes +# successfully, it should be good enough. +test.gam.dataset.error <- function(n) { + sum_insured <- seq(1, 200000, length.out = n) + d2 <- + data.table( + sum_insured = sum_insured, + sqrt = sqrt(sum_insured), + sine = sin(2 * pi * sum_insured / 40000) + ) + d2[, sine := 0.3 * sqrt * sine , ] + d2[, y := pmax(0, sqrt + sine) , ] + + d2[, x := sum_insured] + d2[, x2 := rev(x) , ] # flip axis + + # import the dataset + h2o_data2 <- as.h2o(d2) + + model2 <- + h2o.gam( + y = "y", + gam_columns = c("x2"), + bs = c(2), + spline_orders = c(3), + splines_non_negative = c(F), + training_frame = h2o_data2, + family = "tweedie", + tweedie_variance_power = 1.1, + scale = c(0), + lambda = 0, + alpha = 0, + keep_gam_cols = T, + non_negative = TRUE, + num_knots = c(10) + ) + print("model building completed.") +} + +test.model.gam.dataset.error <- function() { + # test for n=1005 + test.gam.dataset.error(1005) + # test for n=1001; + test.gam.dataset.error(1001) +} + +doTest("General Additive Model dataset size 1001 and 1005 error", test.model.gam.dataset.error)