4343#include "catalog/pg_foreign_server.h"
4444#include "catalog/pg_type.h"
4545#include "catalog/pg_user_mapping.h"
46+ #include "commands/defrem.h"
47+ #include "common/base64.h"
4648#include "executor/spi.h"
4749#include "foreign/foreign.h"
4850#include "funcapi.h"
@@ -126,6 +128,11 @@ static bool is_valid_dblink_option(const PQconninfoOption *options,
126128 const char * option , Oid context );
127129static int applyRemoteGucs (PGconn * conn );
128130static void restoreLocalGucs (int nestlevel );
131+ static bool UseScramPassthrough (ForeignServer * foreign_server , UserMapping * user );
132+ static void appendSCRAMKeysInfo (StringInfo buf );
133+ static bool is_valid_dblink_fdw_option (const PQconninfoOption * options , const char * option ,
134+ Oid context );
135+ static bool dblink_connstr_has_required_scram_options (const char * connstr );
129136
130137/* Global */
131138static remoteConn * pconn = NULL ;
@@ -1964,7 +1971,7 @@ dblink_fdw_validator(PG_FUNCTION_ARGS)
19641971 {
19651972 DefElem * def = (DefElem * ) lfirst (cell );
19661973
1967- if (!is_valid_dblink_option (options , def -> defname , context ))
1974+ if (!is_valid_dblink_fdw_option (options , def -> defname , context ))
19681975 {
19691976 /*
19701977 * Unknown option, or invalid option for the context specified, so
@@ -2596,6 +2603,68 @@ deleteConnection(const char *name)
25962603 errmsg ("undefined connection name" )));
25972604}
25982605
2606+ /*
2607+ * Ensure that require_auth and SCRAM keys are correctly set on connstr.
2608+ * SCRAM keys used to pass-through are coming from the initial connection
2609+ * from the client with the server.
2610+ *
2611+ * All required SCRAM options are set by dblink, so we just need to ensure
2612+ * that these options are not overwritten by the user.
2613+ *
2614+ * See appendSCRAMKeysInfo and its usage for more.
2615+ */
2616+ bool
2617+ dblink_connstr_has_required_scram_options (const char * connstr )
2618+ {
2619+ PQconninfoOption * options ;
2620+ PQconninfoOption * option ;
2621+ bool has_scram_server_key = false;
2622+ bool has_scram_client_key = false;
2623+ bool has_require_auth = false;
2624+ bool has_scram_keys = false;
2625+
2626+ options = PQconninfoParse (connstr , NULL );
2627+ if (options )
2628+ {
2629+ /*
2630+ * Continue iterating even if we found the keys that we need to
2631+ * validate to make sure that there is no other declaration of these
2632+ * keys that can overwrite the first.
2633+ */
2634+ for (option = options ; option -> keyword != NULL ; option ++ )
2635+ {
2636+ if (strcmp (option -> keyword , "require_auth" ) == 0 )
2637+ {
2638+ if (option -> val != NULL && strcmp (option -> val , "scram-sha-256" ) == 0 )
2639+ has_require_auth = true;
2640+ else
2641+ has_require_auth = false;
2642+ }
2643+
2644+ if (strcmp (option -> keyword , "scram_client_key" ) == 0 )
2645+ {
2646+ if (option -> val != NULL && option -> val [0 ] != '\0' )
2647+ has_scram_client_key = true;
2648+ else
2649+ has_scram_client_key = false;
2650+ }
2651+
2652+ if (strcmp (option -> keyword , "scram_server_key" ) == 0 )
2653+ {
2654+ if (option -> val != NULL && option -> val [0 ] != '\0' )
2655+ has_scram_server_key = true;
2656+ else
2657+ has_scram_server_key = false;
2658+ }
2659+ }
2660+ PQconninfoFree (options );
2661+ }
2662+
2663+ has_scram_keys = has_scram_client_key && has_scram_server_key && MyProcPort -> has_scram_keys ;
2664+
2665+ return (has_scram_keys && has_require_auth );
2666+ }
2667+
25992668/*
26002669 * We need to make sure that the connection made used credentials
26012670 * which were provided by the user, so check what credentials were
@@ -2612,6 +2681,18 @@ dblink_security_check(PGconn *conn, remoteConn *rconn, const char *connstr)
26122681 if (PQconnectionUsedPassword (conn ) && dblink_connstr_has_pw (connstr ))
26132682 return ;
26142683
2684+ /*
2685+ * Password was not used to connect, check if SCRAM pass-through is in
2686+ * use.
2687+ *
2688+ * If dblink_connstr_has_required_scram_options is true we assume that
2689+ * UseScramPassthrough is also true because the required SCRAM keys are
2690+ * only added if UseScramPassthrough is set, and the user is not allowed
2691+ * to add the SCRAM keys on fdw and user mapping options.
2692+ */
2693+ if (MyProcPort -> has_scram_keys && dblink_connstr_has_required_scram_options (connstr ))
2694+ return ;
2695+
26152696#ifdef ENABLE_GSS
26162697 /* If GSSAPI creds used to connect, make sure it was one delegated */
26172698 if (PQconnectionUsedGSSAPI (conn ) && be_gssapi_get_delegation (MyProcPort ))
@@ -2664,12 +2745,14 @@ dblink_connstr_has_pw(const char *connstr)
26642745}
26652746
26662747/*
2667- * For non-superusers, insist that the connstr specify a password, except
2668- * if GSSAPI credentials have been delegated (and we check that they are used
2669- * for the connection in dblink_security_check later). This prevents a
2670- * password or GSSAPI credentials from being picked up from .pgpass, a
2671- * service file, the environment, etc. We don't want the postgres user's
2672- * passwords or Kerberos credentials to be accessible to non-superusers.
2748+ * For non-superusers, insist that the connstr specify a password, except if
2749+ * GSSAPI credentials have been delegated (and we check that they are used for
2750+ * the connection in dblink_security_check later) or if SCRAM pass-through is
2751+ * being used. This prevents a password or GSSAPI credentials from being
2752+ * picked up from .pgpass, a service file, the environment, etc. We don't want
2753+ * the postgres user's passwords or Kerberos credentials to be accessible to
2754+ * non-superusers. In case of SCRAM pass-through insist that the connstr
2755+ * has the required SCRAM pass-through options.
26732756 */
26742757static void
26752758dblink_connstr_check (const char * connstr )
@@ -2680,6 +2763,9 @@ dblink_connstr_check(const char *connstr)
26802763 if (dblink_connstr_has_pw (connstr ))
26812764 return ;
26822765
2766+ if (MyProcPort -> has_scram_keys && dblink_connstr_has_required_scram_options (connstr ))
2767+ return ;
2768+
26832769#ifdef ENABLE_GSS
26842770 if (be_gssapi_get_delegation (MyProcPort ))
26852771 return ;
@@ -2832,6 +2918,14 @@ get_connect_string(const char *servername)
28322918 if (aclresult != ACLCHECK_OK )
28332919 aclcheck_error (aclresult , OBJECT_FOREIGN_SERVER , foreign_server -> servername );
28342920
2921+ /*
2922+ * First append hardcoded options needed for SCRAM pass-through, so if
2923+ * the user overwrites these options we can ereport on
2924+ * dblink_connstr_check and dblink_security_check.
2925+ */
2926+ if (MyProcPort -> has_scram_keys && UseScramPassthrough (foreign_server , user_mapping ))
2927+ appendSCRAMKeysInfo (& buf );
2928+
28352929 foreach (cell , fdw -> options )
28362930 {
28372931 DefElem * def = lfirst (cell );
@@ -3016,6 +3110,20 @@ is_valid_dblink_option(const PQconninfoOption *options, const char *option,
30163110 return true;
30173111}
30183112
3113+ /*
3114+ * Same as is_valid_dblink_option but also check for only dblink_fdw specific
3115+ * options.
3116+ */
3117+ static bool
3118+ is_valid_dblink_fdw_option (const PQconninfoOption * options , const char * option ,
3119+ Oid context )
3120+ {
3121+ if (strcmp (option , "use_scram_passthrough" ) == 0 )
3122+ return true;
3123+
3124+ return is_valid_dblink_option (options , option , context );
3125+ }
3126+
30193127/*
30203128 * Copy the remote session's values of GUCs that affect datatype I/O
30213129 * and apply them locally in a new GUC nesting level. Returns the new
@@ -3085,3 +3193,66 @@ restoreLocalGucs(int nestlevel)
30853193 if (nestlevel > 0 )
30863194 AtEOXact_GUC (true, nestlevel );
30873195}
3196+
3197+ /*
3198+ * Append SCRAM client key and server key information from the global
3199+ * MyProcPort into the given StringInfo buffer.
3200+ */
3201+ static void
3202+ appendSCRAMKeysInfo (StringInfo buf )
3203+ {
3204+ int len ;
3205+ int encoded_len ;
3206+ char * client_key ;
3207+ char * server_key ;
3208+
3209+ len = pg_b64_enc_len (sizeof (MyProcPort -> scram_ClientKey ));
3210+ /* don't forget the zero-terminator */
3211+ client_key = palloc0 (len + 1 );
3212+ encoded_len = pg_b64_encode ((const char * ) MyProcPort -> scram_ClientKey ,
3213+ sizeof (MyProcPort -> scram_ClientKey ),
3214+ client_key , len );
3215+ if (encoded_len < 0 )
3216+ elog (ERROR , "could not encode SCRAM client key" );
3217+
3218+ len = pg_b64_enc_len (sizeof (MyProcPort -> scram_ServerKey ));
3219+ /* don't forget the zero-terminator */
3220+ server_key = palloc0 (len + 1 );
3221+ encoded_len = pg_b64_encode ((const char * ) MyProcPort -> scram_ServerKey ,
3222+ sizeof (MyProcPort -> scram_ServerKey ),
3223+ server_key , len );
3224+ if (encoded_len < 0 )
3225+ elog (ERROR , "could not encode SCRAM server key" );
3226+
3227+ appendStringInfo (buf , "scram_client_key='%s' " , client_key );
3228+ appendStringInfo (buf , "scram_server_key='%s' " , server_key );
3229+ appendStringInfo (buf , "require_auth='scram-sha-256' " );
3230+
3231+ pfree (client_key );
3232+ pfree (server_key );
3233+ }
3234+
3235+
3236+ static bool
3237+ UseScramPassthrough (ForeignServer * foreign_server , UserMapping * user )
3238+ {
3239+ ListCell * cell ;
3240+
3241+ foreach (cell , foreign_server -> options )
3242+ {
3243+ DefElem * def = lfirst (cell );
3244+
3245+ if (strcmp (def -> defname , "use_scram_passthrough" ) == 0 )
3246+ return defGetBoolean (def );
3247+ }
3248+
3249+ foreach (cell , user -> options )
3250+ {
3251+ DefElem * def = (DefElem * ) lfirst (cell );
3252+
3253+ if (strcmp (def -> defname , "use_scram_passthrough" ) == 0 )
3254+ return defGetBoolean (def );
3255+ }
3256+
3257+ return false;
3258+ }
0 commit comments