Skip to content

Commit

Permalink
feat: apply all function settings as transaction-scoped settings
Browse files Browse the repository at this point in the history
  • Loading branch information
taimoorzaeem committed Jan 24, 2024
1 parent d7246b4 commit d2fb67f
Show file tree
Hide file tree
Showing 8 changed files with 113 additions and 62 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ This project adheres to [Semantic Versioning](http://semver.org/).
### Added

- #2887, Add Preference `max-affected` to limit affected resources - @taimoorzaeem
- #3061, Apply all function settings as transaction-scoped settings - @taimoorzaeem

### Fixed

Expand Down
50 changes: 27 additions & 23 deletions src/PostgREST/App.hs
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,21 @@ import qualified PostgREST.Query as Query
import qualified PostgREST.Response as Response
import qualified PostgREST.Unix as Unix (installSignalHandlers)

import PostgREST.ApiRequest (Action (..), ApiRequest (..),
Mutation (..), Target (..))
import PostgREST.AppState (AppState)
import PostgREST.Auth (AuthResult (..))
import PostgREST.Config (AppConfig (..))
import PostgREST.Config.PgVersion (PgVersion (..))
import PostgREST.Error (Error)
import PostgREST.Query (DbHandler)
import PostgREST.Response.Performance (ServerTiming (..),
serverTimingHeader)
import PostgREST.SchemaCache (SchemaCache (..))
import PostgREST.SchemaCache.Routine (Routine (..))
import PostgREST.Version (docsVersion, prettyVersion)
import PostgREST.ApiRequest (Action (..),
ApiRequest (..),
Mutation (..), Target (..))
import PostgREST.AppState (AppState)
import PostgREST.Auth (AuthResult (..))
import PostgREST.Config (AppConfig (..))
import PostgREST.Config.PgVersion (PgVersion (..))
import PostgREST.Error (Error)
import PostgREST.Query (DbHandler)
import PostgREST.Response.Performance (ServerTiming (..),
serverTimingHeader)
import PostgREST.SchemaCache (SchemaCache (..))
import PostgREST.SchemaCache.Identifiers (QualifiedIdentifier (..))
import PostgREST.SchemaCache.Routine (Routine (..))
import PostgREST.Version (docsVersion, prettyVersion)

import qualified Data.ByteString.Char8 as BS
import qualified Data.List as L
Expand Down Expand Up @@ -170,43 +172,44 @@ handleRequest AuthResult{..} conf appState authenticated prepared pgVer apiReq@A
case (iAction, iTarget) of
(ActionRead headersOnly, TargetIdent identifier) -> do
(planTime', wrPlan) <- withTiming $ liftEither $ Plan.wrappedReadPlan identifier conf sCache apiReq
(txTime', resultSet) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.wrTxMode wrPlan) $ Query.readQuery wrPlan conf apiReq
(txTime', resultSet) <- withTiming $ runQuery roleIsoLvl (Plan.wrTxMode wrPlan) mempty $ Query.readQuery wrPlan conf apiReq
(respTime', pgrst) <- withTiming $ liftEither $ Response.readResponse wrPlan headersOnly identifier apiReq resultSet
return $ pgrstResponse (ServerTiming jwtTime parseTime planTime' txTime' respTime') pgrst

(ActionMutate MutationCreate, TargetIdent identifier) -> do
(planTime', mrPlan) <- withTiming $ liftEither $ Plan.mutateReadPlan MutationCreate apiReq identifier conf sCache
(txTime', resultSet) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.mrTxMode mrPlan) $ Query.createQuery mrPlan apiReq conf
(txTime', resultSet) <- withTiming $ runQuery roleIsoLvl (Plan.mrTxMode mrPlan) mempty $ Query.createQuery mrPlan apiReq conf
(respTime', pgrst) <- withTiming $ liftEither $ Response.createResponse identifier mrPlan apiReq resultSet
return $ pgrstResponse (ServerTiming jwtTime parseTime planTime' txTime' respTime') pgrst

(ActionMutate MutationUpdate, TargetIdent identifier) -> do
(planTime', mrPlan) <- withTiming $ liftEither $ Plan.mutateReadPlan MutationUpdate apiReq identifier conf sCache
(txTime', resultSet) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.mrTxMode mrPlan) $ Query.updateQuery mrPlan apiReq conf
(txTime', resultSet) <- withTiming $ runQuery roleIsoLvl (Plan.mrTxMode mrPlan) mempty $ Query.updateQuery mrPlan apiReq conf
(respTime', pgrst) <- withTiming $ liftEither $ Response.updateResponse mrPlan apiReq resultSet
return $ pgrstResponse (ServerTiming jwtTime parseTime planTime' txTime' respTime') pgrst

(ActionMutate MutationSingleUpsert, TargetIdent identifier) -> do
(planTime', mrPlan) <- withTiming $ liftEither $ Plan.mutateReadPlan MutationSingleUpsert apiReq identifier conf sCache
(txTime', resultSet) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.mrTxMode mrPlan) $ Query.singleUpsertQuery mrPlan apiReq conf
(txTime', resultSet) <- withTiming $ runQuery roleIsoLvl (Plan.mrTxMode mrPlan) mempty $ Query.singleUpsertQuery mrPlan apiReq conf
(respTime', pgrst) <- withTiming $ liftEither $ Response.singleUpsertResponse mrPlan apiReq resultSet
return $ pgrstResponse (ServerTiming jwtTime parseTime planTime' txTime' respTime') pgrst

(ActionMutate MutationDelete, TargetIdent identifier) -> do
(planTime', mrPlan) <- withTiming $ liftEither $ Plan.mutateReadPlan MutationDelete apiReq identifier conf sCache
(txTime', resultSet) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.mrTxMode mrPlan) $ Query.deleteQuery mrPlan apiReq conf
(txTime', resultSet) <- withTiming $ runQuery roleIsoLvl (Plan.mrTxMode mrPlan) mempty $ Query.deleteQuery mrPlan apiReq conf
(respTime', pgrst) <- withTiming $ liftEither $ Response.deleteResponse mrPlan apiReq resultSet
return $ pgrstResponse (ServerTiming jwtTime parseTime planTime' txTime' respTime') pgrst

(ActionInvoke invMethod, TargetProc identifier _) -> do
(ActionInvoke invMethod, TargetProc identifier@(QualifiedIdentifier _ proname) _) -> do
let setting = [(y,z) | (x,y,z) <- funcSettings, x == encodeUtf8 proname]
(planTime', cPlan) <- withTiming $ liftEither $ Plan.callReadPlan identifier conf sCache apiReq invMethod
(txTime', resultSet) <- withTiming $ runQuery (fromMaybe roleIsoLvl $ pdIsoLvl (Plan.crProc cPlan)) (pdTimeout $ Plan.crProc cPlan) (Plan.crTxMode cPlan) $ Query.invokeQuery (Plan.crProc cPlan) cPlan apiReq conf pgVer
(txTime', resultSet) <- withTiming $ runQuery (fromMaybe roleIsoLvl $ pdIsoLvl (Plan.crProc cPlan)) (Plan.crTxMode cPlan) setting $ Query.invokeQuery (Plan.crProc cPlan) cPlan apiReq conf pgVer
(respTime', pgrst) <- withTiming $ liftEither $ Response.invokeResponse cPlan invMethod (Plan.crProc cPlan) apiReq resultSet
return $ pgrstResponse (ServerTiming jwtTime parseTime planTime' txTime' respTime') pgrst

(ActionInspect headersOnly, TargetDefaultSpec tSchema) -> do
(planTime', iPlan) <- withTiming $ liftEither $ Plan.inspectPlan apiReq
(txTime', oaiResult) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.ipTxmode iPlan) $ Query.openApiQuery sCache pgVer conf tSchema
(txTime', oaiResult) <- withTiming $ runQuery roleIsoLvl (Plan.ipTxmode iPlan) mempty $ Query.openApiQuery sCache pgVer conf tSchema
(respTime', pgrst) <- withTiming $ liftEither $ Response.openApiResponse (T.decodeUtf8 prettyVersion, docsVersion) headersOnly oaiResult conf sCache iSchema iNegotiatedByProfile
return $ pgrstResponse (ServerTiming jwtTime parseTime planTime' txTime' respTime') pgrst

Expand All @@ -230,9 +233,10 @@ handleRequest AuthResult{..} conf appState authenticated prepared pgVer apiReq@A
where
roleSettings = fromMaybe mempty (HM.lookup authRole $ configRoleSettings conf)
roleIsoLvl = HM.findWithDefault SQL.ReadCommitted authRole $ configRoleIsoLvl conf
runQuery isoLvl timeout mode query =
funcSettings = dbFuncSettings sCache
runQuery isoLvl mode funcSet query =
runDbHandler appState conf isoLvl mode authenticated prepared $ do
Query.setPgLocals conf authClaims authRole (HM.toList roleSettings) apiReq timeout
Query.setPgLocals conf authClaims authRole (HM.toList roleSettings) funcSet apiReq
Query.runPreReq conf
query

Expand Down
2 changes: 2 additions & 0 deletions src/PostgREST/Config/Database.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ module PostgREST.Config.Database
, RoleSettings
, RoleIsolationLvl
, TimezoneNames
, FuncSettings
, toIsolationLevel
) where

Expand All @@ -31,6 +32,7 @@ import Protolude
type RoleSettings = (HM.HashMap ByteString (HM.HashMap ByteString ByteString))
type RoleIsolationLvl = HM.HashMap ByteString SQL.IsolationLevel
type TimezoneNames = Set ByteString -- cache timezone names for prefer timezone=
type FuncSettings = [(ByteString,ByteString,ByteString)]

toIsolationLevel :: (Eq a, IsString a) => a -> SQL.IsolationLevel
toIsolationLevel a = case a of
Expand Down
8 changes: 4 additions & 4 deletions src/PostgREST/Query.hs
Original file line number Diff line number Diff line change
Expand Up @@ -247,12 +247,12 @@ optionalRollback AppConfig{..} ApiRequest{iPreferences=Preferences{..}} = do

-- | Set transaction scoped settings
setPgLocals :: AppConfig -> KM.KeyMap JSON.Value -> BS.ByteString -> [(ByteString, ByteString)] ->
ApiRequest -> Maybe Text -> DbHandler ()
setPgLocals AppConfig{..} claims role roleSettings ApiRequest{..} tout = lift $
[(ByteString,ByteString)] -> ApiRequest -> DbHandler ()
setPgLocals AppConfig{..} claims role roleSettings funcSetting ApiRequest{..} = lift $
SQL.statement mempty $ SQL.dynamicallyParameterized
-- To ensure `GRANT SET ON PARAMETER <superuser_setting> TO authenticator` works, the role settings must be set before the impersonated role.
-- Otherwise the GRANT SET would have to be applied to the impersonated role. See https://github.com/PostgREST/postgrest/issues/3045
("select " <> intercalateSnippet ", " (searchPathSql : roleSettingsSql ++ roleSql ++ claimsSql ++ [methodSql, pathSql] ++ headersSql ++ cookiesSql ++ timezoneSql ++ timeoutSql ++ appSettingsSql))
("select " <> intercalateSnippet ", " (searchPathSql : roleSettingsSql ++ roleSql ++ claimsSql ++ [methodSql, pathSql] ++ headersSql ++ cookiesSql ++ timezoneSql ++ funcSettingSql ++ appSettingsSql))
HD.noResult configDbPreparedStatements
where
methodSql = setConfigWithConstantName ("request.method", iMethod)
Expand All @@ -264,7 +264,7 @@ setPgLocals AppConfig{..} claims role roleSettings ApiRequest{..} tout = lift $
roleSettingsSql = setConfigWithDynamicName <$> roleSettings
appSettingsSql = setConfigWithDynamicName <$> (join bimap toUtf8 <$> configAppSettings)
timezoneSql = maybe mempty (\(PreferTimezone tz) -> [setConfigWithConstantName ("timezone", tz)]) $ preferTimezone iPreferences
timeoutSql = maybe mempty ((\t -> [setConfigWithConstantName ("statement_timeout", t)]) . encodeUtf8) tout
funcSettingSql = setConfigWithDynamicName <$> funcSetting
searchPathSql =
let schemas = escapeIdentList (iSchema : configDbExtraSearchPath) in
setConfigWithConstantName ("search_path", schemas)
Expand Down
47 changes: 39 additions & 8 deletions src/PostgREST/SchemaCache.hs
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,14 @@ import Contravariant.Extras (contrazip2)
import Text.InterpolatedString.Perl6 (q)

import PostgREST.Config (AppConfig (..))
import PostgREST.Config.Database (TimezoneNames,
import PostgREST.Config.Database (FuncSettings,
TimezoneNames,
pgVersionStatement,
toIsolationLevel)
import PostgREST.Config.PgVersion (PgVersion, pgVersion100,
pgVersion110,
pgVersion120)
pgVersion120,
pgVersion150)
import PostgREST.SchemaCache.Identifiers (AccessSet, FieldName,
QualifiedIdentifier (..),
RelIdentifier (..),
Expand All @@ -74,24 +76,25 @@ import qualified PostgREST.MediaType as MediaType

import Protolude


data SchemaCache = SchemaCache
{ dbTables :: TablesMap
, dbRelationships :: RelationshipsMap
, dbRoutines :: RoutineMap
, dbRepresentations :: RepresentationsMap
, dbMediaHandlers :: MediaHandlerMap
, dbTimezones :: TimezoneNames
, dbFuncSettings :: FuncSettings
}

instance JSON.ToJSON SchemaCache where
toJSON (SchemaCache tabs rels routs reps _ _) = JSON.object [
toJSON (SchemaCache tabs rels routs reps _ _ _) = JSON.object [
"dbTables" .= JSON.toJSON tabs
, "dbRelationships" .= JSON.toJSON rels
, "dbRoutines" .= JSON.toJSON routs
, "dbRepresentations" .= JSON.toJSON reps
, "dbMediaHandlers" .= JSON.emptyArray
, "dbTimezones" .= JSON.emptyArray
, "dbFuncSettings" .= JSON.emptyArray
]

-- | A view foreign key or primary key dependency detected on its source table
Expand Down Expand Up @@ -145,6 +148,7 @@ querySchemaCache AppConfig{..} = do
reps <- SQL.statement schemas $ dataRepresentations prepared
mHdlers <- SQL.statement schemas $ mediaHandlers pgVer prepared
tzones <- SQL.statement mempty $ timezones prepared
funSets <- SQL.statement mempty $ funcSettings pgVer prepared
_ <-
let sleepCall = SQL.Statement "select pg_sleep($1)" (param HE.int4) HD.noResult prepared in
whenJust configInternalSCSleep (`SQL.statement` sleepCall) -- only used for testing
Expand All @@ -159,6 +163,7 @@ querySchemaCache AppConfig{..} = do
, dbRepresentations = reps
, dbMediaHandlers = HM.union mHdlers initialMediaHandlers -- the custom handlers will override the initial ones
, dbTimezones = tzones
, dbFuncSettings = funSets
}
where
schemas = toList configDbSchemas
Expand Down Expand Up @@ -195,6 +200,7 @@ removeInternal schemas dbStruct =
, dbRepresentations = dbRepresentations dbStruct -- no need to filter, not directly exposed through the API
, dbMediaHandlers = dbMediaHandlers dbStruct
, dbTimezones = dbTimezones dbStruct
, dbFuncSettings = dbFuncSettings dbStruct
}
where
hasInternalJunction ComputedRelationship{} = False
Expand Down Expand Up @@ -297,7 +303,6 @@ decodeFuncs =
<*> (parseVolatility <$> column HD.char)
<*> column HD.bool
<*> nullableColumn (toIsolationLevel <$> HD.text)
<*> nullableColumn HD.text

addKey :: Routine -> (QualifiedIdentifier, Routine)
addKey pd = (QualifiedIdentifier (pdSchema pd) (pdName pd), pd)
Expand Down Expand Up @@ -431,8 +436,7 @@ funcsSqlQuery pgVer = [q|
bt.oid <> bt.base as rettype_is_composite_alias,
p.provolatile,
p.provariadic > 0 as hasvariadic,
lower((regexp_split_to_array((regexp_split_to_array(iso_config, '='))[2], ','))[1]) AS transaction_isolation_level,
lower((regexp_split_to_array((regexp_split_to_array(timeout_config, '='))[2], ','))[1]) AS statement_timeout
lower((regexp_split_to_array((regexp_split_to_array(iso_config, '='))[2], ','))[1]) AS transaction_isolation_level
FROM pg_proc p
LEFT JOIN arguments a ON a.oid = p.oid
JOIN pg_namespace pn ON pn.oid = p.pronamespace
Expand All @@ -442,7 +446,6 @@ funcsSqlQuery pgVer = [q|
LEFT JOIN pg_class comp ON comp.oid = t.typrelid
LEFT JOIN pg_description as d ON d.objoid = p.oid
LEFT JOIN LATERAL unnest(proconfig) iso_config ON iso_config like 'default_transaction_isolation%'
LEFT JOIN LATERAL unnest(proconfig) timeout_config ON timeout_config like 'statement_timeout%'
WHERE t.oid <> 'trigger'::regtype AND COALESCE(a.callable, true)
|] <> (if pgVer >= pgVersion110 then "AND prokind = 'f'" else "AND NOT (proisagg OR proiswindow)")

Expand Down Expand Up @@ -1203,6 +1206,34 @@ timezones = SQL.Statement sql HE.noParams decodeTimezones
decodeTimezones :: HD.Result TimezoneNames
decodeTimezones = S.fromList . map encodeUtf8 <$> HD.rowList (column HD.text)

funcSettings :: PgVersion -> Bool -> SQL.Statement () FuncSettings
funcSettings pgVer = SQL.Statement sql HE.noParams rows
where
sql = [q|
WITH
func_setting AS (
SELECT p.proname, unnest(p.proconfig) AS setting
FROM pg_proc p
),
kv_settings AS (
SELECT
proname,
substr(setting, 1, strpos(setting, '=') - 1) as key,
lower(substr(setting, strpos(setting, '=') + 1)) as value
FROM func_setting
)
SELECT
proname, kv.key AS key, kv.value AS value
FROM kv_settings kv
JOIN pg_settings ps ON ps.name = kv.key |] <>
(if pgVer >= pgVersion150
then "and (ps.context = 'user' or has_parameter_privilege(current_user::regrole::oid, ps.name, 'set'));"
else "and ps.context = 'user';")

Check warning on line 1231 in src/PostgREST/SchemaCache.hs

View check run for this annotation

Codecov / codecov/patch

src/PostgREST/SchemaCache.hs#L1231

Added line #L1231 was not covered by tests

rows :: HD.Result FuncSettings
rows = HD.rowList $ (,,) <$> (encodeUtf8 <$> column HD.text) <*> (encodeUtf8 <$> column HD.text) <*> (encodeUtf8 <$> column HD.text)


param :: HE.Value a -> HE.Params a
param = HE.param . HE.nonNullable

Expand Down
8 changes: 3 additions & 5 deletions src/PostgREST/SchemaCache/Routine.hs
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,11 @@ data Routine = Function
, pdVolatility :: FuncVolatility
, pdHasVariadic :: Bool
, pdIsoLvl :: Maybe SQL.IsolationLevel
, pdTimeout :: Maybe Text
}
deriving (Eq, Show, Generic)
-- need to define JSON manually bc SQL.IsolationLevel doesn't have a JSON instance(and we can't define one for that type without getting a compiler error)
instance JSON.ToJSON Routine where
toJSON (Function sch nam desc params ret vol hasVar _ tout) = JSON.object
toJSON (Function sch nam desc params ret vol hasVar _) = JSON.object
[
"pdSchema" .= sch
, "pdName" .= nam
Expand All @@ -72,7 +71,6 @@ instance JSON.ToJSON Routine where
, "pdReturnType" .= JSON.toJSON ret
, "pdVolatility" .= JSON.toJSON vol
, "pdHasVariadic" .= JSON.toJSON hasVar
, "pdTimeout" .= tout
]

data RoutineParam = RoutineParam
Expand All @@ -86,10 +84,10 @@ data RoutineParam = RoutineParam

-- Order by least number of params in the case of overloaded functions
instance Ord Routine where
Function schema1 name1 des1 prms1 rt1 vol1 hasVar1 iso1 tout1 `compare` Function schema2 name2 des2 prms2 rt2 vol2 hasVar2 iso2 tout2
Function schema1 name1 des1 prms1 rt1 vol1 hasVar1 iso1 `compare` Function schema2 name2 des2 prms2 rt2 vol2 hasVar2 iso2
| schema1 == schema2 && name1 == name2 && length prms1 < length prms2 = LT
| schema2 == schema2 && name1 == name2 && length prms1 > length prms2 = GT
| otherwise = (schema1, name1, des1, prms1, rt1, vol1, hasVar1, iso1, tout1) `compare` (schema2, name2, des2, prms2, rt2, vol2, hasVar2, iso2, tout2)
| otherwise = (schema1, name1, des1, prms1, rt1, vol1, hasVar1, iso1) `compare` (schema2, name2, des2, prms2, rt2, vol2, hasVar2, iso2)

-- | A map of all procs, all of which can be overloaded(one entry will have more than one Routine).
-- | It uses a HashMap for a faster lookup.
Expand Down
6 changes: 6 additions & 0 deletions test/io/fixtures.sql
Original file line number Diff line number Diff line change
Expand Up @@ -198,3 +198,9 @@ $$ language sql set statement_timeout = '4s';
create function get_postgres_version() returns int as $$
select current_setting('server_version_num')::int;
$$ language sql;

GRANT SET ON PARAMETER log_min_duration_sample TO postgrest_test_anonymous;

create or replace function log_min_duration_test() returns text as $$
select current_setting('log_min_duration_sample',false);
$$ language sql set log_min_duration_sample = '5s';

0 comments on commit d2fb67f

Please sign in to comment.