User defined aggregation function

Building a customized Recall function
SQL
Oracle
ML
Metric
Author

Johannes Tomasoni

Published

November 7, 2022

About

This article describes the steps to create a customized aggregation function in Oracle.

\[Recall = \frac{TP}{TP+FN}\]

Recall is implemented as the example, where TP is the number of true positives and FN the number of false negatives.
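
As a quick worked example with made-up counts (illustrative only, not taken from the data used later): with 8 true positives and 2 false negatives,

\[Recall = \frac{8}{8+2} = 0.8\]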

Implementation

Define a type for the input values if the input isn’t a native Oracle type (NUMBER, VARCHAR2, …) or if the aggregation function takes more than one input value.

CREATE OR REPLACE TYPE y_pair FORCE AS OBJECT (
  y_true NUMBER, 
  y_pred NUMBER, 
  threshold NUMBER DEFAULT 0.5
);
/
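
Object type attributes traditionally cannot carry a DEFAULT clause in Oracle, so the definition above may be rejected depending on the release. One possible workaround, sketched here as an assumption rather than as the article's original code, is to drop the default and add a two-argument constructor that fills in the 0.5 threshold:

```sql
-- Sketch: y_pair without an attribute default.
-- The user-defined two-argument constructor supplies threshold = 0.5.
CREATE OR REPLACE TYPE y_pair FORCE AS OBJECT (
  y_true    NUMBER,
  y_pred    NUMBER,
  threshold NUMBER,
  CONSTRUCTOR FUNCTION y_pair(y_true NUMBER, y_pred NUMBER)
    RETURN SELF AS RESULT
);
/
CREATE OR REPLACE TYPE BODY y_pair IS
  CONSTRUCTOR FUNCTION y_pair(y_true NUMBER, y_pred NUMBER)
    RETURN SELF AS RESULT IS
  BEGIN
    self.y_true    := y_true;
    self.y_pred    := y_pred;
    self.threshold := 0.5;  -- default threshold
    RETURN;
  END;
END;
/
```

With this variant, y_pair(y_true => 1, y_pred => 0.8) resolves to the user-defined constructor, while the three-argument form y_pair(1, 0.8, 0.7) still uses the constructor Oracle generates from the attribute list.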

Implement the logic of the user-defined function in an Oracle object type.

-- Create type header
CREATE OR REPLACE TYPE RecallImpl AS OBJECT(
  -- define object variables
  tp NUMBER, -- number of true positives
  fn NUMBER, -- number of false negatives

  -- Interface for Oracle user-defined aggregate functions.
  -- Initialize, Iterate, Merge and Terminate are mandatory; Delete is optional
  STATIC FUNCTION ODCIAggregateInitialize(sctx IN OUT RecallImpl) RETURN NUMBER,
  MEMBER FUNCTION ODCIAggregateIterate(self IN OUT RecallImpl, value IN y_pair) RETURN NUMBER,
  MEMBER FUNCTION ODCIAggregateDelete(self IN OUT RecallImpl, value IN y_pair) RETURN NUMBER,
  MEMBER FUNCTION ODCIAggregateTerminate(self IN RecallImpl, recall OUT NUMBER, flags IN NUMBER) RETURN NUMBER,
  MEMBER FUNCTION ODCIAggregateMerge(self IN OUT RecallImpl, ctx2 IN RecallImpl) RETURN NUMBER
);
/
-- Create type body
CREATE OR REPLACE TYPE BODY RecallImpl IS 
  

  STATIC FUNCTION ODCIAggregateInitialize(sctx IN OUT RecallImpl) 
                                                        RETURN NUMBER IS 
  BEGIN

    -- initialize tp and fn with 0
    sctx := RecallImpl(0, 0);
  
    RETURN ODCIConst.Success;
  END;


  -- Increase tp or fn count if the condition is satisfied
  MEMBER FUNCTION ODCIAggregateIterate(self IN OUT RecallImpl, 
                                       value IN y_pair) RETURN NUMBER IS
  BEGIN

    IF value.y_true = 1 THEN
      -- increase true positive count
      IF value.y_pred >= value.y_true - value.threshold THEN
        self.tp := self.tp + 1;
      END IF;
      
      -- increase false negative count
      IF value.y_pred < value.y_true - value.threshold THEN
        self.fn := self.fn + 1;
      END IF;
    END IF;
    
    RETURN ODCIConst.Success;
  END;


  -- In case of aggregation in window function RECALL_CAF(...)OVER(ORDER BY...) 
  -- the entry leaving the window needs to decrease tp and fn count 
  -- if condition is satisfied
  MEMBER FUNCTION ODCIAggregateDelete(self IN OUT RecallImpl, 
                                      value IN y_pair) RETURN NUMBER IS
  BEGIN

    IF value.y_true = 1 THEN
      -- decrease true positive count
      IF value.y_pred >= value.y_true - value.threshold THEN
        self.tp := self.tp - 1;
      END IF;
      
      -- decrease false negative count
      IF value.y_pred < value.y_true - value.threshold THEN
        self.fn := self.fn - 1;
      END IF;
    END IF;

    RETURN ODCIConst.Success;
  END;


  -- Calculate *Recall* from tp and fn count
  MEMBER FUNCTION ODCIAggregateTerminate(self IN RecallImpl, 
                                         recall OUT NUMBER, 
                                         flags IN NUMBER) RETURN NUMBER IS
  BEGIN

    -- NULLIF avoids a division by zero: the result is NULL
    -- when the group contains no positive samples
    recall := self.tp / nullif(self.tp + self.fn, 0);

    RETURN ODCIConst.Success;
  END;


  -- In case of parallel execution combine tp and fn counts 
  -- from parallel threads
  MEMBER FUNCTION ODCIAggregateMerge(self IN OUT RecallImpl, 
                                     ctx2 IN RecallImpl) RETURN NUMBER IS
  BEGIN

    self.tp := self.tp + ctx2.tp;
    self.fn := self.fn + ctx2.fn;
  
    RETURN ODCIConst.Success;
  END;


END;
/

Parallel aggregation of data:

flowchart TB
  A[Init] --> B1[Iter /\nDel]
  B1 --> B1
  A --> B2[Iter /\nDel]
  B2 --> B2
  B1 --> C[Merge]
  B2 --> C[Merge]
  C --> D[Terminate]
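
The flow in the chart can be traced by hand in an anonymous PL/SQL block. This is only a sketch with made-up input values; it assumes the y_pair and RecallImpl types compile as defined above and that server output is enabled in the client (for example SET SERVEROUTPUT ON in SQL*Plus).

```sql
DECLARE
  ctx1     RecallImpl;  -- context of the first "parallel" branch
  ctx2     RecallImpl;  -- context of the second "parallel" branch
  rc       NUMBER;      -- return code of each ODCI call
  v_recall NUMBER;
BEGIN
  -- Init: one context per branch
  rc := RecallImpl.ODCIAggregateInitialize(ctx1);
  rc := RecallImpl.ODCIAggregateInitialize(ctx2);

  -- Iterate: feed made-up (y_true, y_pred, threshold) triples
  rc := ctx1.ODCIAggregateIterate(y_pair(1, 0.9, 0.5));  -- true positive
  rc := ctx1.ODCIAggregateIterate(y_pair(0, 0.6, 0.5));  -- ignored (y_true = 0)
  rc := ctx2.ODCIAggregateIterate(y_pair(1, 0.2, 0.5));  -- false negative
  rc := ctx2.ODCIAggregateIterate(y_pair(1, 0.7, 0.5));  -- true positive

  -- Merge: fold the second context into the first
  rc := ctx1.ODCIAggregateMerge(ctx2);

  -- Terminate: tp = 2, fn = 1, so recall = 2 / (2 + 1)
  rc := ctx1.ODCIAggregateTerminate(v_recall, 0);

  DBMS_OUTPUT.PUT_LINE('recall = ' || ROUND(v_recall, 3));
END;
/
```

In a real query Oracle drives these calls itself; the block only makes the order of the steps visible.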

Define the function that uses the implemented logic.

CREATE OR REPLACE FUNCTION recall_caf(value y_pair) 
                    RETURN NUMBER PARALLEL_ENABLE AGGREGATE USING RecallImpl;
/

The function can now be used as an aggregate function, including as an analytic function with window clauses.

Example: Recall per kfold for two thresholds.

SELECT 
  r.kfold,
  recall_caf(y_pair(y_true => r.y_true, 
                    y_pred => r.y_oof)) recall_0_5,
  recall_caf(y_pair(y_true => r.y_true, 
                    y_pred => r.y_oof, 
                    threshold => 0.7)) recall_0_7
FROM oof_results r
GROUP BY r.kfold;

KFOLD  RECALL_0_5  RECALL_0_7
0      0.72        0.81
1      0.781       0.841
2      0.754       0.81
3      0.79        0.838
4      0.71        0.72
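
As a cross-check, the same number can be computed with built-in aggregates only. The sketch below mirrors the logic in ODCIAggregateIterate for the default threshold: a positive sample counts as a true positive when y_oof >= 1 - 0.5. It assumes y_oof is never NULL for positive samples; otherwise the denominators of the two approaches would differ.

```sql
-- Recall per kfold without the user-defined aggregate (threshold 0.5)
SELECT
  r.kfold,
  SUM(CASE WHEN r.y_true = 1 AND r.y_oof >= 1 - 0.5 THEN 1 ELSE 0 END)
    / NULLIF(SUM(CASE WHEN r.y_true = 1 THEN 1 ELSE 0 END), 0) AS recall_0_5
FROM oof_results r
GROUP BY r.kfold;
```

The user-defined aggregate keeps this logic in one place, which pays off once the same metric is needed in many queries.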

Or per row, using the analytic form:

SELECT 
  r.kfold,
  r.y_true,
  r.y_oof,
  recall_caf(y_pair(y_true => r.y_true, 
                    y_pred => r.y_oof)) OVER (PARTITION BY r.kfold) recall_0_5,
  recall_caf(y_pair(y_true => r.y_true, 
                    y_pred => r.y_oof, 
                    threshold => 0.7)) OVER (PARTITION BY r.kfold) recall_0_7
FROM oof_results r;

KFOLD  Y_TRUE  Y_OOF   RECALL_0_5  RECALL_0_7
0      1       0.9768  0.72        0.81
0      0       0.56    0.72        0.81
0      0       0.121   0.72        0.81
0      1       0.433   0.72        0.81
0      1       0.6768  0.72        0.81
2      0       0.0068  0.754       0.87
2      1       0.877   0.754       0.87
2      0       0.1021  0.754       0.87
Resources

More details can be found in the Oracle documentation:

  • https://docs.oracle.com/cd/B10501_01/appdev.920/a96595/dci11agg.htm
  • https://docs.oracle.com/cd/B12037_01/appdev.101/b10800/dciaggref.htm