diff --git a/src/DIRAC/Core/Utilities/MySQL.py b/src/DIRAC/Core/Utilities/MySQL.py index e69af1a5fbc..d48ecccd349 100755 --- a/src/DIRAC/Core/Utilities/MySQL.py +++ b/src/DIRAC/Core/Utilities/MySQL.py @@ -138,6 +138,15 @@ Count the number of records on each distinct combination of AttrList, selected with condition defined by condDict and time stamps +getCounters( self, table, attrList, condDict = None, older = None, + newer = None, timeStamp = None, connection = False, + greater = None, smaller = None, inner_join = "" ): + + Count the number of records on each distinct combination of AttrList, selected + with condition defined by condDict and time stamps. Optional greater/smaller + allow for inequality filters and inner_join lets callers inject a JOIN on a + temporary or auxiliary table (e.g. large ID sets) for performance. + getDistinctAttributeValues( self, table, attribute, condDict = None, older = None, newer = None, timeStamp = None, connection = False ): @@ -1148,6 +1157,7 @@ def getCounters( connection=False, greater=None, smaller=None, + inner_join="", ): """ Count the number of records on each distinct combination of AttrList, selected @@ -1172,7 +1182,10 @@ def getCounters( except Exception as x: return S_ERROR(DErrno.EMYSQL, x) - cmd = f"SELECT {attrNames}, COUNT(*) FROM {table} {cond} GROUP BY {attrNames} ORDER BY {attrNames}" + # inner_join can be provided by higher level DB helpers to speed up large IN lists + # by joining on a temporary in-memory table. It should either be an empty string + # or start with a space and contain a complete JOIN clause. + cmd = f"SELECT {attrNames}, COUNT(*) FROM {table}{inner_join} {cond} GROUP BY {attrNames} ORDER BY {attrNames}" res = self._query(cmd, conn=connection) if not res["OK"]: return res diff --git a/src/DIRAC/TransformationSystem/DB/TransformationDB.py b/src/DIRAC/TransformationSystem/DB/TransformationDB.py index b5a1da10b04..5374502ed3d 100755 --- a/src/DIRAC/TransformationSystem/DB/TransformationDB.py +++ b/src/DIRAC/TransformationSystem/DB/TransformationDB.py @@ -344,6 +344,73 @@ def getTransformations( return resultList + def getCounters( + self, + table, + attrList, + condDict, + older=None, + newer=None, + timeStamp=None, + connection=False, + greater=None, + smaller=None, + ): + """Optimized getCounters override. + + For large lists of TransformationID values (length > TMP_TABLE_JOIN_LIMIT), we + create an in-memory temporary table and JOIN on it instead of relying on a + potentially very large IN (...) clause. This mirrors the optimization used in + getTransformations. + + Parameters mirror parent MySQL.getCounters except we build an internal + inner_join clause transparently. + """ + # Ensure we have a connection object + connection = self.__getConnection(connection) + + # Work on a copy so we do not mutate caller's dictionary + localCondDict = dict(condDict) if condDict else {} + join_query = "" + try: + if ( + "TransformationID" in localCondDict + and isinstance(localCondDict["TransformationID"], list) + and len(localCondDict["TransformationID"]) > TMP_TABLE_JOIN_LIMIT + ): + transIDs = localCondDict.pop("TransformationID") + # Create temporary table + sqlCmd = "CREATE TEMPORARY TABLE to_query_TransformationIDs (TransID INTEGER NOT NULL, PRIMARY KEY (TransID)) ENGINE=MEMORY;" + res = self._update(sqlCmd, conn=connection) + if not res["OK"]: + return res + join_query = " JOIN to_query_TransformationIDs t ON TransformationID = t.TransID" + # Bulk insert IDs + sqlCmd = "INSERT INTO to_query_TransformationIDs (TransID) VALUES ( %s )" + res = self._updatemany(sqlCmd, [(tid,) for tid in transIDs], conn=connection) + if not res["OK"]: + return res + + # Delegate to parent with inner_join parameter + res = super().getCounters( + table, + attrList, + localCondDict, + older=older, + newer=newer, + timeStamp=timeStamp, + connection=connection, + greater=greater, + smaller=smaller, + inner_join=join_query, + ) + finally: + if join_query: + # Drop temp table + self._update("DROP TEMPORARY TABLE to_query_TransformationIDs", conn=connection) + + return res + def getTransformation(self, transName, extraParams=False, connection=False): """Get Transformation definition and parameters of Transformation identified by TransformationID""" res = self._getConnectionTransID(connection, transName)