From 0ef3c59ebff1511f5a74f8393563ed56ce4dd5d8 Mon Sep 17 00:00:00 2001 From: Li Haoyi Date: Sat, 18 May 2024 13:56:05 +0800 Subject: [PATCH] Make `@postJson` handle incoming JSON in a streaming manner, introduce `@postJsonCached` (#123) --- cask/src/cask/endpoints/JsonEndpoint.scala | 19 +++++-------------- cask/src/cask/package.scala | 1 + .../formJsonPost/app/src/FormJsonPost.scala | 11 +++++++++++ .../app/test/src/ExampleTests.scala | 13 +++++++++++++ example/httpMethods/build.sc | 1 + 5 files changed, 31 insertions(+), 14 deletions(-) diff --git a/cask/src/cask/endpoints/JsonEndpoint.scala b/cask/src/cask/endpoints/JsonEndpoint.scala index 5c421cac65..50591a9007 100644 --- a/cask/src/cask/endpoints/JsonEndpoint.scala +++ b/cask/src/cask/endpoints/JsonEndpoint.scala @@ -42,26 +42,17 @@ object JsonData extends DataCompanion[JsonData]{ } } -class postJson(val path: String, override val subpath: Boolean = false) +class postJsonCached(path: String, subpath: Boolean = false) extends postJsonBase(path, subpath, true) +class postJson(path: String, subpath: Boolean = false) extends postJsonBase(path, subpath, false) +abstract class postJsonBase(val path: String, override val subpath: Boolean = false, cacheBody: Boolean = false) extends HttpEndpoint[Response[JsonData], ujson.Value]{ val methods = Seq("post") type InputParser[T] = JsReader[T] - def wrapFunction(ctx: Request, - delegate: Delegate): Result[Response.Raw] = { + def wrapFunction(ctx: Request, delegate: Delegate): Result[Response.Raw] = { val obj = for{ - str <- - try { - val boas = new ByteArrayOutputStream() - Util.transferTo(ctx.exchange.getInputStream, boas) - Right(new String(boas.toByteArray)) - } - catch{case e: Throwable => Left(cask.model.Response( - "Unable to deserialize input JSON text: " + e + "\n" + Util.stackTraceString(e), - statusCode = 400 - ))} json <- - try Right(ujson.read(str)) + try Right(ujson.read(if (cacheBody) ctx.bytes else ctx.exchange.getInputStream)) catch{case e: Throwable => Left(cask.model.Response( "Input text is invalid JSON: " + e + "\n" + Util.stackTraceString(e), statusCode = 400 diff --git a/cask/src/cask/package.scala b/cask/src/cask/package.scala index 8ef0a1b81d..a5dbbe9a76 100644 --- a/cask/src/cask/package.scala +++ b/cask/src/cask/package.scala @@ -37,6 +37,7 @@ package object cask { type staticFiles = endpoints.staticFiles type staticResources = endpoints.staticResources type postJson = endpoints.postJson + type postJsonCached = endpoints.postJsonCached type getJson = endpoints.getJson type postForm = endpoints.postForm type options = endpoints.options diff --git a/example/formJsonPost/app/src/FormJsonPost.scala b/example/formJsonPost/app/src/FormJsonPost.scala index e43968c927..57c47d8b11 100644 --- a/example/formJsonPost/app/src/FormJsonPost.scala +++ b/example/formJsonPost/app/src/FormJsonPost.scala @@ -13,6 +13,17 @@ object FormJsonPost extends cask.MainRoutes{ ) } + @cask.postJsonCached("/json-obj-cached") + def jsonEndpointObjCached(value1: ujson.Value, value2: Seq[Int], request: cask.Request) = { + ujson.Obj( + "value1" -> value1, + "value2" -> value2, + // `cacheBody = true` buffers up the body of the request in memory before parsing, + // giving you access to the request body data if you want to use it yourself + "body" -> request.text() + ) + } + @cask.postForm("/form") def formEndpoint(value1: cask.FormValue, value2: Seq[Int]) = { "OK " + value1 + " " + value2 diff --git a/example/formJsonPost/app/test/src/ExampleTests.scala b/example/formJsonPost/app/test/src/ExampleTests.scala index adb2ee1b76..7055689da1 100644 --- a/example/formJsonPost/app/test/src/ExampleTests.scala +++ b/example/formJsonPost/app/test/src/ExampleTests.scala @@ -30,6 +30,19 @@ object ExampleTests extends TestSuite{ ) ujson.read(response2.text()) ==> ujson.Obj("value1" -> true, "value2" -> ujson.Arr(3)) + + val response2Cached = requests.post( + s"$host/json-obj-cached", + data = """{"value1": true, "value2": [3]}""" + ) + ujson.read(response2Cached.text()) ==> + ujson.Obj( + "value1" -> true, + "value2" -> ujson.Arr(3), + "body" -> """{"value1": true, "value2": [3]}""" + ) + + val response3 = requests.post( s"$host/form", data = Seq("value1" -> "hello", "value2" -> "1", "value2" -> "2") diff --git a/example/httpMethods/build.sc b/example/httpMethods/build.sc index 75de91ef2f..edb70116ba 100644 --- a/example/httpMethods/build.sc +++ b/example/httpMethods/build.sc @@ -10,5 +10,6 @@ trait AppModule extends CrossScalaModule{ ivy"com.lihaoyi::utest::0.8.1", ivy"com.lihaoyi::requests::0.8.0", ) + def forkArgs = Seq("--add-opens=java.base/java.net=ALL-UNNAMED") } }