diff options
author | Tom Lane | 2016-06-22 20:52:41 +0000 |
---|---|---|
committer | Tom Lane | 2016-06-22 20:52:41 +0000 |
commit | f8ace5477ef9731ef605f58d313c4cd1548f12d2 (patch) | |
tree | f2c4c43a145eb9c16af539de4748afb5b9cb423d /src/backend | |
parent | e45e990e4b547f05bdb46e4596d24abbaef60043 (diff) |
Fix type-safety problem with parallel aggregate serial/deserialization.
The original specification for this called for the deserialization function
to have signature "deserialize(serialtype) returns transtype", which is a
security violation if transtype is INTERNAL (which it always would be in
practice) and serialtype is not (which ditto). The patch blithely overrode
the opr_sanity check for that, which was sloppy-enough work in itself,
but the indisputable reason this cannot be allowed to stand is that CREATE
FUNCTION will reject such a signature and thus it'd be impossible for
extensions to create parallelizable aggregates.
The minimum fix to make the signature type-safe is to add a second, dummy
argument of type INTERNAL. But to lock it down a bit more and make misuse
of INTERNAL-accepting functions less likely, let's get rid of the ability
to specify a "serialtype" for an aggregate and just say that the only
useful serialtype is BYTEA --- which, in practice, is the only interesting
value anyway, due to the usefulness of the send/recv infrastructure for
this purpose. That means we only have to allow "serialize(internal)
returns bytea" and "deserialize(bytea, internal) returns internal" as
the signatures for these support functions.
In passing fix bogus signature of int4_avg_combine, which I found thanks
to adding an opr_sanity check on combinefunc signatures.
catversion bump due to removing pg_aggregate.aggserialtype and adjusting
signatures of assorted built-in functions.
David Rowley and Tom Lane
Discussion: <[email protected]>
Diffstat (limited to 'src/backend')
-rw-r--r-- | src/backend/catalog/pg_aggregate.c | 27 | ||||
-rw-r--r-- | src/backend/commands/aggregatecmds.c | 69 | ||||
-rw-r--r-- | src/backend/executor/nodeAgg.c | 55 | ||||
-rw-r--r-- | src/backend/optimizer/util/clauses.c | 8 | ||||
-rw-r--r-- | src/backend/optimizer/util/tlist.c | 37 | ||||
-rw-r--r-- | src/backend/parser/parse_agg.c | 136 | ||||
-rw-r--r-- | src/backend/utils/adt/numeric.c | 61 |
7 files changed, 139 insertions, 254 deletions
diff --git a/src/backend/catalog/pg_aggregate.c b/src/backend/catalog/pg_aggregate.c index 73d19ec3947..959d3845df2 100644 --- a/src/backend/catalog/pg_aggregate.c +++ b/src/backend/catalog/pg_aggregate.c @@ -67,7 +67,6 @@ AggregateCreate(const char *aggName, bool mfinalfnExtraArgs, List *aggsortopName, Oid aggTransType, - Oid aggSerialType, int32 aggTransSpace, Oid aggmTransType, int32 aggmTransSpace, @@ -440,44 +439,42 @@ AggregateCreate(const char *aggName, } /* - * Validate the serialization function, if present. We must ensure that - * the return type of this function is the same as the specified - * serialType. + * Validate the serialization function, if present. */ if (aggserialfnName) { - fnArgs[0] = aggTransType; + fnArgs[0] = INTERNALOID; serialfn = lookup_agg_function(aggserialfnName, 1, fnArgs, variadicArgType, &rettype); - if (rettype != aggSerialType) + if (rettype != BYTEAOID) ereport(ERROR, (errcode(ERRCODE_DATATYPE_MISMATCH), errmsg("return type of serialization function %s is not %s", NameListToString(aggserialfnName), - format_type_be(aggSerialType)))); + format_type_be(BYTEAOID)))); } /* - * Validate the deserialization function, if present. We must ensure that - * the return type of this function is the same as the transType. + * Validate the deserialization function, if present. */ if (aggdeserialfnName) { - fnArgs[0] = aggSerialType; + fnArgs[0] = BYTEAOID; + fnArgs[1] = INTERNALOID; /* dummy argument for type safety */ - deserialfn = lookup_agg_function(aggdeserialfnName, 1, + deserialfn = lookup_agg_function(aggdeserialfnName, 2, fnArgs, variadicArgType, &rettype); - if (rettype != aggTransType) + if (rettype != INTERNALOID) ereport(ERROR, (errcode(ERRCODE_DATATYPE_MISMATCH), errmsg("return type of deserialization function %s is not %s", NameListToString(aggdeserialfnName), - format_type_be(aggTransType)))); + format_type_be(INTERNALOID)))); } /* @@ -661,7 +658,6 @@ AggregateCreate(const char *aggName, values[Anum_pg_aggregate_aggmfinalextra - 1] = BoolGetDatum(mfinalfnExtraArgs); values[Anum_pg_aggregate_aggsortop - 1] = ObjectIdGetDatum(sortop); values[Anum_pg_aggregate_aggtranstype - 1] = ObjectIdGetDatum(aggTransType); - values[Anum_pg_aggregate_aggserialtype - 1] = ObjectIdGetDatum(aggSerialType); values[Anum_pg_aggregate_aggtransspace - 1] = Int32GetDatum(aggTransSpace); values[Anum_pg_aggregate_aggmtranstype - 1] = ObjectIdGetDatum(aggmTransType); values[Anum_pg_aggregate_aggmtransspace - 1] = Int32GetDatum(aggmTransSpace); @@ -688,8 +684,7 @@ AggregateCreate(const char *aggName, * Create dependencies for the aggregate (above and beyond those already * made by ProcedureCreate). Note: we don't need an explicit dependency * on aggTransType since we depend on it indirectly through transfn. - * Likewise for aggmTransType using the mtransfunc, and also for - * aggSerialType using the serialfn, if they exist. + * Likewise for aggmTransType using the mtransfunc, if it exists. */ /* Depends on transition function */ diff --git a/src/backend/commands/aggregatecmds.c b/src/backend/commands/aggregatecmds.c index f1fdc1a3603..d34c82c5baf 100644 --- a/src/backend/commands/aggregatecmds.c +++ b/src/backend/commands/aggregatecmds.c @@ -72,7 +72,6 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters, List *sortoperatorName = NIL; TypeName *baseType = NULL; TypeName *transType = NULL; - TypeName *serialType = NULL; TypeName *mtransType = NULL; int32 transSpace = 0; int32 mtransSpace = 0; @@ -88,7 +87,6 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters, List *parameterDefaults; Oid variadicArgType; Oid transTypeId; - Oid serialTypeId = InvalidOid; Oid mtransTypeId = InvalidOid; char transTypeType; char mtransTypeType = 0; @@ -164,8 +162,6 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters, } else if (pg_strcasecmp(defel->defname, "stype") == 0) transType = defGetTypeName(defel); - else if (pg_strcasecmp(defel->defname, "serialtype") == 0) - serialType = defGetTypeName(defel); else if (pg_strcasecmp(defel->defname, "stype1") == 0) transType = defGetTypeName(defel); else if (pg_strcasecmp(defel->defname, "sspace") == 0) @@ -333,73 +329,25 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters, format_type_be(transTypeId)))); } - if (serialType) + if (serialfuncName && deserialfuncName) { /* - * There's little point in having a serialization/deserialization - * function on aggregates that don't have an internal state, so let's - * just disallow this as it may help clear up any confusion or - * needless authoring of these functions. + * Serialization is only needed/allowed for transtype INTERNAL. */ if (transTypeId != INTERNALOID) ereport(ERROR, (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), - errmsg("a serialization type must only be specified when the aggregate transition data type is %s", + errmsg("serialization functions may be specified only when the aggregate transition data type is %s", format_type_be(INTERNALOID)))); - - serialTypeId = typenameTypeId(NULL, serialType); - - if (get_typtype(mtransTypeId) == TYPTYPE_PSEUDO && - !IsPolymorphicType(serialTypeId)) - ereport(ERROR, - (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), - errmsg("aggregate serialization data type cannot be %s", - format_type_be(serialTypeId)))); - - /* - * We disallow INTERNAL serialType as the whole point of the - * serialized types is to allow the aggregate state to be output, and - * we cannot output INTERNAL. This check, combined with the one above - * ensures that the trans type and serialization type are not the - * same. - */ - if (serialTypeId == INTERNALOID) - ereport(ERROR, - (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), - errmsg("aggregate serialization data type cannot be %s", - format_type_be(serialTypeId)))); - - /* - * If serialType is specified then serialfuncName and deserialfuncName - * must be present; if not, then none of the serialization options - * should have been specified. - */ - if (serialfuncName == NIL) - ereport(ERROR, - (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), - errmsg("aggregate serialization function must be specified when serialization type is specified"))); - - if (deserialfuncName == NIL) - ereport(ERROR, - (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), - errmsg("aggregate deserialization function must be specified when serialization type is specified"))); } - else + else if (serialfuncName || deserialfuncName) { /* - * If serialization type was not specified then there shouldn't be a - * serialization function. + * Cannot specify one function without the other. */ - if (serialfuncName != NIL) - ereport(ERROR, - (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), - errmsg("must specify serialization type when specifying serialization function"))); - - /* likewise for the deserialization function */ - if (deserialfuncName != NIL) - ereport(ERROR, - (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), - errmsg("must specify serialization type when specifying deserialization function"))); + ereport(ERROR, + (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), + errmsg("must specify both or neither of serialization and deserialization functions"))); } /* @@ -493,7 +441,6 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters, mfinalfuncExtraArgs, sortoperatorName, /* sort operator name */ transTypeId, /* transition data type */ - serialTypeId, /* serialization data type */ transSpace, /* transition space */ mtransTypeId, /* transition data type */ mtransSpace, /* transition space */ diff --git a/src/backend/executor/nodeAgg.c b/src/backend/executor/nodeAgg.c index 7b282dec7da..a4479646129 100644 --- a/src/backend/executor/nodeAgg.c +++ b/src/backend/executor/nodeAgg.c @@ -514,10 +514,9 @@ static Datum GetAggInitVal(Datum textInitVal, Oid transtype); static void build_pertrans_for_aggref(AggStatePerTrans pertrans, AggState *aggsate, EState *estate, Aggref *aggref, Oid aggtransfn, Oid aggtranstype, - Oid aggserialtype, Oid aggserialfn, - Oid aggdeserialfn, Datum initValue, - bool initValueIsNull, Oid *inputTypes, - int numArguments); + Oid aggserialfn, Oid aggdeserialfn, + Datum initValue, bool initValueIsNull, + Oid *inputTypes, int numArguments); static int find_compatible_peragg(Aggref *newagg, AggState *aggstate, int lastaggno, List **same_input_transnos); static int find_compatible_pertrans(AggState *aggstate, Aggref *newagg, @@ -996,6 +995,9 @@ combine_aggregates(AggState *aggstate, AggStatePerGroup pergroup) dsinfo->arg[0] = slot->tts_values[0]; dsinfo->argnull[0] = slot->tts_isnull[0]; + /* Dummy second argument for type-safety reasons */ + dsinfo->arg[1] = PointerGetDatum(NULL); + dsinfo->argnull[1] = false; /* * We run the deserialization functions in per-input-tuple @@ -2669,8 +2671,7 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) AclResult aclresult; Oid transfn_oid, finalfn_oid; - Oid serialtype_oid, - serialfn_oid, + Oid serialfn_oid, deserialfn_oid; Expr *finalfnexpr; Oid aggtranstype; @@ -2740,7 +2741,6 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) else peragg->finalfn_oid = finalfn_oid = InvalidOid; - serialtype_oid = InvalidOid; serialfn_oid = InvalidOid; deserialfn_oid = InvalidOid; @@ -2753,13 +2753,9 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) { /* * The planner should only have generated an agg node with - * serialStates if every aggregate with an INTERNAL state has a - * serialization type, serialization function and deserialization - * function. Let's ensure it didn't mess that up. + * serialStates if every aggregate with an INTERNAL state has + * serialization/deserialization functions. Verify that. */ - if (!OidIsValid(aggform->aggserialtype)) - elog(ERROR, "serialtype not set during serialStates aggregation step"); - if (!OidIsValid(aggform->aggserialfn)) elog(ERROR, "serialfunc not set during serialStates aggregation step"); @@ -2768,17 +2764,11 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) /* serialization func only required when not finalizing aggs */ if (!aggstate->finalizeAggs) - { serialfn_oid = aggform->aggserialfn; - serialtype_oid = aggform->aggserialtype; - } /* deserialization func only required when combining states */ if (aggstate->combineStates) - { deserialfn_oid = aggform->aggdeserialfn; - serialtype_oid = aggform->aggserialtype; - } } /* Check that aggregate owner has permission to call component fns */ @@ -2906,10 +2896,9 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) pertrans = &pertransstates[++transno]; build_pertrans_for_aggref(pertrans, aggstate, estate, aggref, transfn_oid, aggtranstype, - serialtype_oid, serialfn_oid, - deserialfn_oid, initValue, - initValueIsNull, inputTypes, - numArguments); + serialfn_oid, deserialfn_oid, + initValue, initValueIsNull, + inputTypes, numArguments); peragg->transno = transno; } ReleaseSysCache(aggTuple); @@ -2937,7 +2926,7 @@ static void build_pertrans_for_aggref(AggStatePerTrans pertrans, AggState *aggstate, EState *estate, Aggref *aggref, - Oid aggtransfn, Oid aggtranstype, Oid aggserialtype, + Oid aggtransfn, Oid aggtranstype, Oid aggserialfn, Oid aggdeserialfn, Datum initValue, bool initValueIsNull, Oid *inputTypes, int numArguments) @@ -3065,10 +3054,7 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans, if (OidIsValid(aggserialfn)) { - build_aggregate_serialfn_expr(aggtranstype, - aggserialtype, - aggref->inputcollid, - aggserialfn, + build_aggregate_serialfn_expr(aggserialfn, &serialfnexpr); fmgr_info(aggserialfn, &pertrans->serialfn); fmgr_info_set_expr((Node *) serialfnexpr, &pertrans->serialfn); @@ -3076,24 +3062,21 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans, InitFunctionCallInfoData(pertrans->serialfn_fcinfo, &pertrans->serialfn, 1, - pertrans->aggCollation, + InvalidOid, (void *) aggstate, NULL); } if (OidIsValid(aggdeserialfn)) { - build_aggregate_serialfn_expr(aggserialtype, - aggtranstype, - aggref->inputcollid, - aggdeserialfn, - &deserialfnexpr); + build_aggregate_deserialfn_expr(aggdeserialfn, + &deserialfnexpr); fmgr_info(aggdeserialfn, &pertrans->deserialfn); fmgr_info_set_expr((Node *) deserialfnexpr, &pertrans->deserialfn); InitFunctionCallInfoData(pertrans->deserialfn_fcinfo, &pertrans->deserialfn, - 1, - pertrans->aggCollation, + 2, + InvalidOid, (void *) aggstate, NULL); } diff --git a/src/backend/optimizer/util/clauses.c b/src/backend/optimizer/util/clauses.c index 0e738c1ccc0..7138cad31d8 100644 --- a/src/backend/optimizer/util/clauses.c +++ b/src/backend/optimizer/util/clauses.c @@ -465,13 +465,11 @@ aggregates_allow_partial_walker(Node *node, partial_agg_context *context) /* * If we find any aggs with an internal transtype then we must check - * that these have a serialization type, serialization func and - * deserialization func; otherwise, we set the maximum allowed type to - * PAT_INTERNAL_ONLY. + * whether these have serialization/deserialization functions; + * otherwise, we set the maximum allowed type to PAT_INTERNAL_ONLY. */ if (aggform->aggtranstype == INTERNALOID && - (!OidIsValid(aggform->aggserialtype) || - !OidIsValid(aggform->aggserialfn) || + (!OidIsValid(aggform->aggserialfn) || !OidIsValid(aggform->aggdeserialfn))) context->allowedtype = PAT_INTERNAL_ONLY; diff --git a/src/backend/optimizer/util/tlist.c b/src/backend/optimizer/util/tlist.c index de0a8c7b57f..5fa80ac51be 100644 --- a/src/backend/optimizer/util/tlist.c +++ b/src/backend/optimizer/util/tlist.c @@ -15,7 +15,7 @@ #include "postgres.h" #include "access/htup_details.h" -#include "catalog/pg_aggregate.h" +#include "catalog/pg_type.h" #include "nodes/makefuncs.h" #include "nodes/nodeFuncs.h" #include "optimizer/tlist.h" @@ -766,8 +766,8 @@ apply_pathtarget_labeling_to_tlist(List *tlist, PathTarget *target) /* * apply_partialaggref_adjustment * Convert PathTarget to be suitable for a partial aggregate node. We simply - * adjust any Aggref nodes found in the target and set the aggoutputtype to - * the aggtranstype or aggserialtype. This allows exprType() to return the + * adjust any Aggref nodes found in the target and set the aggoutputtype + * appropriately. This allows exprType() to return the * actual type that will be produced. * * Note: We expect 'target' to be a flat target list and not have Aggrefs buried @@ -784,40 +784,29 @@ apply_partialaggref_adjustment(PathTarget *target) if (IsA(aggref, Aggref)) { - HeapTuple aggTuple; - Form_pg_aggregate aggform; Aggref *newaggref; - aggTuple = SearchSysCache1(AGGFNOID, - ObjectIdGetDatum(aggref->aggfnoid)); - if (!HeapTupleIsValid(aggTuple)) - elog(ERROR, "cache lookup failed for aggregate %u", - aggref->aggfnoid); - aggform = (Form_pg_aggregate) GETSTRUCT(aggTuple); - newaggref = (Aggref *) copyObject(aggref); /* - * Use the serialization type, if one exists. Note that we don't - * support it being a polymorphic type. (XXX really we ought to - * hardwire this as INTERNAL -> BYTEA, and avoid a catalog lookup - * here altogether?) + * Normally, a partial aggregate returns the aggregate's + * transition type, but if that's INTERNAL, it returns BYTEA + * instead. (XXX this assumes we're doing parallel aggregate with + * serialization; later we might need an argument to tell this + * function whether we're doing parallel or just local partial + * aggregation.) */ - if (OidIsValid(aggform->aggserialtype)) - newaggref->aggoutputtype = aggform->aggserialtype; + Assert(OidIsValid(newaggref->aggtranstype)); + + if (newaggref->aggtranstype == INTERNALOID) + newaggref->aggoutputtype = BYTEAOID; else - { - /* Otherwise, we return the aggregate's transition type */ - Assert(OidIsValid(newaggref->aggtranstype)); newaggref->aggoutputtype = newaggref->aggtranstype; - } /* flag it as partial */ newaggref->aggpartial = true; lfirst(lc) = newaggref; - - ReleaseSysCache(aggTuple); } } } diff --git a/src/backend/parser/parse_agg.c b/src/backend/parser/parse_agg.c index b9ca066698e..481a4ddc484 100644 --- a/src/backend/parser/parse_agg.c +++ b/src/backend/parser/parse_agg.c @@ -71,6 +71,8 @@ static bool finalize_grouping_exprs_walker(Node *node, check_ungrouped_columns_context *context); static void check_agglevels_and_constraints(ParseState *pstate, Node *expr); static List *expand_groupingset_node(GroupingSet *gs); +static Node *make_agg_arg(Oid argtype, Oid argcollation); + /* * transformAggregateCall - @@ -1863,37 +1865,19 @@ build_aggregate_transfn_expr(Oid *agg_input_types, Expr **transfnexpr, Expr **invtransfnexpr) { - Param *argp; List *args; FuncExpr *fexpr; int i; /* - * Build arg list to use in the transfn FuncExpr node. We really only care - * that transfn can discover the actual argument types at runtime using - * get_fn_expr_argtype(), so it's okay to use Param nodes that don't - * correspond to any real Param. + * Build arg list to use in the transfn FuncExpr node. */ - argp = makeNode(Param); - argp->paramkind = PARAM_EXEC; - argp->paramid = -1; - argp->paramtype = agg_state_type; - argp->paramtypmod = -1; - argp->paramcollid = agg_input_collation; - argp->location = -1; - - args = list_make1(argp); + args = list_make1(make_agg_arg(agg_state_type, agg_input_collation)); for (i = agg_num_direct_inputs; i < agg_num_inputs; i++) { - argp = makeNode(Param); - argp->paramkind = PARAM_EXEC; - argp->paramid = -1; - argp->paramtype = agg_input_types[i]; - argp->paramtypmod = -1; - argp->paramcollid = agg_input_collation; - argp->location = -1; - args = lappend(args, argp); + args = lappend(args, + make_agg_arg(agg_input_types[i], agg_input_collation)); } fexpr = makeFuncExpr(transfn_oid, @@ -1936,20 +1920,13 @@ build_aggregate_combinefn_expr(Oid agg_state_type, Oid combinefn_oid, Expr **combinefnexpr) { - Param *argp; + Node *argp; List *args; FuncExpr *fexpr; - /* Build arg list to use in the combinefn FuncExpr node. */ - argp = makeNode(Param); - argp->paramkind = PARAM_EXEC; - argp->paramid = -1; - argp->paramtype = agg_state_type; - argp->paramtypmod = -1; - argp->paramcollid = agg_input_collation; - argp->location = -1; + /* combinefn takes two arguments of the aggregate state type */ + argp = make_agg_arg(agg_state_type, agg_input_collation); - /* transition state type is arg 1 and 2 */ args = list_make2(argp, argp); fexpr = makeFuncExpr(combinefn_oid, @@ -1958,51 +1935,59 @@ build_aggregate_combinefn_expr(Oid agg_state_type, InvalidOid, agg_input_collation, COERCE_EXPLICIT_CALL); - fexpr->funcvariadic = false; + /* combinefn is currently never treated as variadic */ *combinefnexpr = (Expr *) fexpr; } /* * Like build_aggregate_transfn_expr, but creates an expression tree for the - * serialization or deserialization function of an aggregate, rather than the - * transition function. This may be used for either the serialization or - * deserialization function by swapping the first two parameters over. + * serialization function of an aggregate. */ void -build_aggregate_serialfn_expr(Oid agg_input_type, - Oid agg_output_type, - Oid agg_input_collation, - Oid serialfn_oid, +build_aggregate_serialfn_expr(Oid serialfn_oid, Expr **serialfnexpr) { - Param *argp; List *args; FuncExpr *fexpr; - /* Build arg list to use in the FuncExpr node. */ - argp = makeNode(Param); - argp->paramkind = PARAM_EXEC; - argp->paramid = -1; - argp->paramtype = agg_input_type; - argp->paramtypmod = -1; - argp->paramcollid = agg_input_collation; - argp->location = -1; - - /* takes a single arg of the agg_input_type */ - args = list_make1(argp); + /* serialfn always takes INTERNAL and returns BYTEA */ + args = list_make1(make_agg_arg(INTERNALOID, InvalidOid)); fexpr = makeFuncExpr(serialfn_oid, - agg_output_type, + BYTEAOID, args, InvalidOid, - agg_input_collation, + InvalidOid, COERCE_EXPLICIT_CALL); - fexpr->funcvariadic = false; *serialfnexpr = (Expr *) fexpr; } /* * Like build_aggregate_transfn_expr, but creates an expression tree for the + * deserialization function of an aggregate. + */ +void +build_aggregate_deserialfn_expr(Oid deserialfn_oid, + Expr **deserialfnexpr) +{ + List *args; + FuncExpr *fexpr; + + /* deserialfn always takes BYTEA, INTERNAL and returns INTERNAL */ + args = list_make2(make_agg_arg(BYTEAOID, InvalidOid), + make_agg_arg(INTERNALOID, InvalidOid)); + + fexpr = makeFuncExpr(deserialfn_oid, + INTERNALOID, + args, + InvalidOid, + InvalidOid, + COERCE_EXPLICIT_CALL); + *deserialfnexpr = (Expr *) fexpr; +} + +/* + * Like build_aggregate_transfn_expr, but creates an expression tree for the * final function of an aggregate, rather than the transition function. */ void @@ -2014,33 +1999,19 @@ build_aggregate_finalfn_expr(Oid *agg_input_types, Oid finalfn_oid, Expr **finalfnexpr) { - Param *argp; List *args; int i; /* * Build expr tree for final function */ - argp = makeNode(Param); - argp->paramkind = PARAM_EXEC; - argp->paramid = -1; - argp->paramtype = agg_state_type; - argp->paramtypmod = -1; - argp->paramcollid = agg_input_collation; - argp->location = -1; - args = list_make1(argp); + args = list_make1(make_agg_arg(agg_state_type, agg_input_collation)); /* finalfn may take additional args, which match agg's input types */ for (i = 0; i < num_finalfn_inputs - 1; i++) { - argp = makeNode(Param); - argp->paramkind = PARAM_EXEC; - argp->paramid = -1; - argp->paramtype = agg_input_types[i]; - argp->paramtypmod = -1; - argp->paramcollid = agg_input_collation; - argp->location = -1; - args = lappend(args, argp); + args = lappend(args, + make_agg_arg(agg_input_types[i], agg_input_collation)); } *finalfnexpr = (Expr *) makeFuncExpr(finalfn_oid, @@ -2051,3 +2022,24 @@ build_aggregate_finalfn_expr(Oid *agg_input_types, COERCE_EXPLICIT_CALL); /* finalfn is currently never treated as variadic */ } + +/* + * Convenience function to build dummy argument expressions for aggregates. + * + * We really only care that an aggregate support function can discover its + * actual argument types at runtime using get_fn_expr_argtype(), so it's okay + * to use Param nodes that don't correspond to any real Param. + */ +static Node * +make_agg_arg(Oid argtype, Oid argcollation) +{ + Param *argp = makeNode(Param); + + argp->paramkind = PARAM_EXEC; + argp->paramid = -1; + argp->paramtype = argtype; + argp->paramtypmod = -1; + argp->paramcollid = argcollation; + argp->location = -1; + return (Node *) argp; +} diff --git a/src/backend/utils/adt/numeric.c b/src/backend/utils/adt/numeric.c index 6592ef4d2d9..f0b3b87f4c3 100644 --- a/src/backend/utils/adt/numeric.c +++ b/src/backend/utils/adt/numeric.c @@ -3510,10 +3510,7 @@ numeric_avg_combine(PG_FUNCTION_ARGS) /* * numeric_avg_serialize * Serialize NumericAggState for numeric aggregates that don't require - * sumX2. Serializes NumericAggState into bytea using the standard pq API. - * - * numeric_avg_deserialize(numeric_avg_serialize(state)) must result in a state - * which matches the original input state. + * sumX2. */ Datum numeric_avg_serialize(PG_FUNCTION_ARGS) @@ -3564,17 +3561,13 @@ numeric_avg_serialize(PG_FUNCTION_ARGS) /* * numeric_avg_deserialize - * Deserialize bytea into NumericAggState for numeric aggregates that - * don't require sumX2. Deserializes bytea into NumericAggState using the - * standard pq API. - * - * numeric_avg_serialize(numeric_avg_deserialize(bytea)) must result in a value - * which matches the original bytea value. + * Deserialize bytea into NumericAggState for numeric aggregates that + * don't require sumX2. */ Datum numeric_avg_deserialize(PG_FUNCTION_ARGS) { - bytea *sstate = PG_GETARG_BYTEA_P(0); + bytea *sstate; NumericAggState *result; Datum temp; StringInfoData buf; @@ -3582,6 +3575,8 @@ numeric_avg_deserialize(PG_FUNCTION_ARGS) if (!AggCheckCallContext(fcinfo, NULL)) elog(ERROR, "aggregate function called in non-aggregate context"); + sstate = PG_GETARG_BYTEA_P(0); + /* * Copy the bytea into a StringInfo so that we can "receive" it using the * standard pq API. @@ -3619,11 +3614,7 @@ numeric_avg_deserialize(PG_FUNCTION_ARGS) /* * numeric_serialize * Serialization function for NumericAggState for numeric aggregates that - * require sumX2. Serializes NumericAggState into bytea using the standard - * pq API. - * - * numeric_deserialize(numeric_serialize(state)) must result in a state which - * matches the original input state. + * require sumX2. */ Datum numeric_serialize(PG_FUNCTION_ARGS) @@ -3683,16 +3674,12 @@ numeric_serialize(PG_FUNCTION_ARGS) /* * numeric_deserialize * Deserialization function for NumericAggState for numeric aggregates that - * require sumX2. Deserializes bytea into into NumericAggState using the - * standard pq API. - * - * numeric_serialize(numeric_deserialize(bytea)) must result in a value which - * matches the original bytea value. + * require sumX2. */ Datum numeric_deserialize(PG_FUNCTION_ARGS) { - bytea *sstate = PG_GETARG_BYTEA_P(0); + bytea *sstate; NumericAggState *result; Datum temp; StringInfoData buf; @@ -3700,6 +3687,8 @@ numeric_deserialize(PG_FUNCTION_ARGS) if (!AggCheckCallContext(fcinfo, NULL)) elog(ERROR, "aggregate function called in non-aggregate context"); + sstate = PG_GETARG_BYTEA_P(0); + /* * Copy the bytea into a StringInfo so that we can "receive" it using the * standard pq API. @@ -3992,11 +3981,8 @@ numeric_poly_combine(PG_FUNCTION_ARGS) /* * numeric_poly_serialize - * Serialize PolyNumAggState into bytea using the standard pq API for - * aggregate functions which require sumX2. - * - * numeric_poly_deserialize(numeric_poly_serialize(state)) must result in a - * state which matches the original input state. + * Serialize PolyNumAggState into bytea for aggregate functions which + * require sumX2. */ Datum numeric_poly_serialize(PG_FUNCTION_ARGS) @@ -4067,16 +4053,13 @@ numeric_poly_serialize(PG_FUNCTION_ARGS) /* * numeric_poly_deserialize - * Deserialize PolyNumAggState from bytea using the standard pq API for - * aggregate functions which require sumX2. - * - * numeric_poly_serialize(numeric_poly_deserialize(bytea)) must result in a - * state which matches the original input state. + * Deserialize PolyNumAggState from bytea for aggregate functions which + * require sumX2. */ Datum numeric_poly_deserialize(PG_FUNCTION_ARGS) { - bytea *sstate = PG_GETARG_BYTEA_P(0); + bytea *sstate; PolyNumAggState *result; Datum sumX; Datum sumX2; @@ -4085,6 +4068,8 @@ numeric_poly_deserialize(PG_FUNCTION_ARGS) if (!AggCheckCallContext(fcinfo, NULL)) elog(ERROR, "aggregate function called in non-aggregate context"); + sstate = PG_GETARG_BYTEA_P(0); + /* * Copy the bytea into a StringInfo so that we can "receive" it using the * standard pq API. @@ -4226,9 +4211,6 @@ int8_avg_combine(PG_FUNCTION_ARGS) /* * int8_avg_serialize * Serialize PolyNumAggState into bytea using the standard pq API. - * - * int8_avg_deserialize(int8_avg_serialize(state)) must result in a state which - * matches the original input state. */ Datum int8_avg_serialize(PG_FUNCTION_ARGS) @@ -4286,14 +4268,11 @@ int8_avg_serialize(PG_FUNCTION_ARGS) /* * int8_avg_deserialize * Deserialize bytea back into PolyNumAggState. - * - * int8_avg_serialize(int8_avg_deserialize(bytea)) must result in a value which - * matches the original bytea value. */ Datum int8_avg_deserialize(PG_FUNCTION_ARGS) { - bytea *sstate = PG_GETARG_BYTEA_P(0); + bytea *sstate; PolyNumAggState *result; StringInfoData buf; Datum temp; @@ -4301,6 +4280,8 @@ int8_avg_deserialize(PG_FUNCTION_ARGS) if (!AggCheckCallContext(fcinfo, NULL)) elog(ERROR, "aggregate function called in non-aggregate context"); + sstate = PG_GETARG_BYTEA_P(0); + /* * Copy the bytea into a StringInfo so that we can "receive" it using the * standard pq API. |