Skip to content

Commit 01834ed

Browse files
updated other sketches
1 parent 388041f commit 01834ed

1 file changed

Lines changed: 134 additions & 19 deletions

File tree

asap-query-engine/src/precompute_engine/accumulator_factory.rs

Lines changed: 134 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -700,35 +700,37 @@ pub fn config_is_keyed(config: &AggregationConfig) -> bool {
700700

701701
/// Extract the KLL `k` parameter. Capital `"K"` takes precedence over lowercase
702702
/// `"k"` to match the convention used by the top-level aggregation type arms.
703-
fn kll_k_param(config: &AggregationConfig) -> u16 {
703+
fn kll_k_param(config: &AggregationConfig) -> Result<u16, String> {
704704
config
705705
.parameters
706706
.get("K")
707707
.or_else(|| config.parameters.get("k"))
708708
.and_then(|v| v.as_u64())
709709
.and_then(|v| u16::try_from(v).ok())
710-
.unwrap_or(200)
710+
.ok_or_else(|| "KLL config missing required parameter (tried: K, k)".to_string())
711711
}
712712

713713
/// Extract `(row_num, col_num)` for CMS / HydraKLL configs.
714-
fn cms_params(config: &AggregationConfig) -> (usize, usize) {
714+
fn cms_params(config: &AggregationConfig) -> Result<(usize, usize), String> {
715715
let row_num = config
716716
.parameters
717717
.get("row_num")
718718
.and_then(|v| v.as_u64())
719-
.unwrap_or(4) as usize;
719+
.ok_or_else(|| "CMS config missing required parameter: row_num".to_string())?
720+
as usize;
720721
let col_num = config
721722
.parameters
722723
.get("col_num")
723724
.and_then(|v| v.as_u64())
724-
.unwrap_or(1000) as usize;
725-
(row_num, col_num)
725+
.ok_or_else(|| "CMS config missing required parameter: col_num".to_string())?
726+
as usize;
727+
Ok((row_num, col_num))
726728
}
727729

728730
/// Extract `(row_num, col_num, k)` for HydraKLL configs.
729-
fn hydra_kll_params(config: &AggregationConfig) -> (usize, usize, u16) {
730-
let (row_num, col_num) = cms_params(config);
731-
(row_num, col_num, kll_k_param(config))
731+
fn hydra_kll_params(config: &AggregationConfig) -> Result<(usize, usize, u16), String> {
732+
let (row_num, col_num) = cms_params(config)?;
733+
Ok((row_num, col_num, kll_k_param(config)?))
732734
}
733735

734736
/// Extract `(row_num, col_num, heap_size)` for CountMinSketchWithHeap configs.
@@ -796,7 +798,8 @@ fn hll_precision_param(config: &AggregationConfig) -> u32 {
796798
/// Create an appropriate `AccumulatorUpdater` from an `AggregationConfig`.
797799
///
798800
/// Returns `Err` if the config is of a type that requires specific parameters
799-
/// (e.g. `CountMinSketchWithHeap`) but those parameters are absent or invalid.
801+
/// (e.g. `CountMinSketchWithHeap`, `CountMinSketch`, `HydraKLL`, KLL variants)
802+
/// but those parameters are absent or invalid.
800803
pub fn create_accumulator_updater(
801804
config: &AggregationConfig,
802805
) -> Result<Box<dyn AccumulatorUpdater>, String> {
@@ -809,7 +812,7 @@ pub fn create_accumulator_updater(
809812
"Max" | "max" => Ok(Box::new(MinMaxAccumulatorUpdater::new(true))),
810813
"Increase" | "increase" => Ok(Box::new(IncreaseAccumulatorUpdater::new())),
811814
"DatasketchesKLL" | "datasketches_kll" | "KLL" | "kll" => {
812-
Ok(Box::new(KllAccumulatorUpdater::new(kll_k_param(config))))
815+
Ok(Box::new(KllAccumulatorUpdater::new(kll_k_param(config)?)))
813816
}
814817
other => {
815818
tracing::warn!(
@@ -825,11 +828,11 @@ pub fn create_accumulator_updater(
825828
"Max" | "max" => Ok(Box::new(MultipleMinMaxAccumulatorUpdater::new(true))),
826829
"Increase" | "increase" => Ok(Box::new(MultipleIncreaseAccumulatorUpdater::new())),
827830
"CountMinSketch" | "count_min_sketch" | "CMS" | "cms" => {
828-
let (row_num, col_num) = cms_params(config);
831+
let (row_num, col_num) = cms_params(config)?;
829832
Ok(Box::new(CmsAccumulatorUpdater::new(row_num, col_num)))
830833
}
831834
"HydraKLL" | "hydra_kll" => {
832-
let (row_num, col_num, k) = hydra_kll_params(config);
835+
let (row_num, col_num, k) = hydra_kll_params(config)?;
833836
Ok(Box::new(HydraKllAccumulatorUpdater::new(
834837
row_num, col_num, k,
835838
)))
@@ -843,7 +846,7 @@ pub fn create_accumulator_updater(
843846
}
844847
},
845848
AggregationType::DatasketchesKLL => {
846-
Ok(Box::new(KllAccumulatorUpdater::new(kll_k_param(config))))
849+
Ok(Box::new(KllAccumulatorUpdater::new(kll_k_param(config)?)))
847850
}
848851
AggregationType::MultipleSum => Ok(Box::new(MultipleSumAccumulatorUpdater::new())),
849852
AggregationType::MultipleIncrease => {
@@ -858,7 +861,7 @@ pub fn create_accumulator_updater(
858861
))),
859862
AggregationType::Increase => Ok(Box::new(IncreaseAccumulatorUpdater::new())),
860863
AggregationType::CountMinSketch => {
861-
let (row_num, col_num) = cms_params(config);
864+
let (row_num, col_num) = cms_params(config)?;
862865
Ok(Box::new(CmsAccumulatorUpdater::new(row_num, col_num)))
863866
}
864867
AggregationType::CountMinSketchWithHeap => {
@@ -871,7 +874,7 @@ pub fn create_accumulator_updater(
871874
)))
872875
}
873876
AggregationType::HydraKLL => {
874-
let (row_num, col_num, k) = hydra_kll_params(config);
877+
let (row_num, col_num, k) = hydra_kll_params(config)?;
875878
Ok(Box::new(HydraKllAccumulatorUpdater::new(
876879
row_num, col_num, k,
877880
)))
@@ -1028,13 +1031,11 @@ mod tests {
10281031
)));
10291032
assert!(config_is_keyed(&make_config(AggregationType::HydraKLL, "")));
10301033

1031-
// Verify agreement with updater.is_keyed()
1034+
// Verify agreement with updater.is_keyed() for types that need no sketch params.
10321035
for (agg_type, sub_type) in &[
10331036
(AggregationType::SingleSubpopulation, "Sum"),
10341037
(AggregationType::MultipleSubpopulation, "Sum"),
10351038
(AggregationType::MultipleSum, ""),
1036-
(AggregationType::DatasketchesKLL, ""),
1037-
(AggregationType::CountMinSketch, ""),
10381039
] {
10391040
let config = make_config(*agg_type, sub_type);
10401041
let updater = create_accumulator_updater(&config).unwrap();
@@ -1045,6 +1046,45 @@ mod tests {
10451046
agg_type
10461047
);
10471048
}
1049+
1050+
// Sketch types require params — build configs with the required parameters.
1051+
let make_config_with_params =
1052+
|agg_type: AggregationType,
1053+
sub_type: &str,
1054+
params: std::collections::HashMap<String, serde_json::Value>| {
1055+
AggregationConfig::new(
1056+
1,
1057+
agg_type,
1058+
sub_type.to_string(),
1059+
params,
1060+
promql_utilities::data_model::key_by_label_names::KeyByLabelNames::new(vec![]),
1061+
promql_utilities::data_model::key_by_label_names::KeyByLabelNames::new(vec![]),
1062+
promql_utilities::data_model::key_by_label_names::KeyByLabelNames::new(vec![]),
1063+
String::new(),
1064+
60,
1065+
0,
1066+
WindowType::Tumbling,
1067+
"m".to_string(),
1068+
"m".to_string(),
1069+
None,
1070+
None,
1071+
None,
1072+
None,
1073+
)
1074+
};
1075+
for (agg_type, sub_type, params) in [
1076+
(AggregationType::DatasketchesKLL, "", kll_params_required()),
1077+
(AggregationType::CountMinSketch, "", cms_params_required()),
1078+
] {
1079+
let config = make_config_with_params(agg_type, sub_type, params);
1080+
let updater = create_accumulator_updater(&config).unwrap();
1081+
assert_eq!(
1082+
config_is_keyed(&config),
1083+
updater.is_keyed(),
1084+
"config_is_keyed disagrees with updater.is_keyed() for type={:?}",
1085+
agg_type
1086+
);
1087+
}
10481088
}
10491089

10501090
// HLL: `AggregationType::HLL` must build `HllAccumulatorUpdater` (hashes samples
@@ -1241,6 +1281,81 @@ mod tests {
12411281
assert_eq!(kll.inner.k, 50, "k should be 50 from capital-K param");
12421282
}
12431283

1284+
#[test]
1285+
fn test_kll_missing_k_param_returns_err() {
1286+
use std::collections::HashMap;
1287+
let config = AggregationConfig::new(
1288+
1,
1289+
AggregationType::DatasketchesKLL,
1290+
String::new(),
1291+
HashMap::new(),
1292+
promql_utilities::data_model::key_by_label_names::KeyByLabelNames::new(vec![]),
1293+
promql_utilities::data_model::key_by_label_names::KeyByLabelNames::new(vec![]),
1294+
promql_utilities::data_model::key_by_label_names::KeyByLabelNames::new(vec![]),
1295+
String::new(),
1296+
60,
1297+
0,
1298+
WindowType::Tumbling,
1299+
"m".to_string(),
1300+
"m".to_string(),
1301+
None,
1302+
None,
1303+
None,
1304+
None,
1305+
);
1306+
let err = create_accumulator_updater(&config)
1307+
.err()
1308+
.expect("expected Err for missing K param");
1309+
assert!(
1310+
err.contains("K"),
1311+
"error should mention the missing parameter name"
1312+
);
1313+
}
1314+
1315+
#[test]
1316+
fn test_cms_missing_params_returns_err() {
1317+
use std::collections::HashMap;
1318+
let config = AggregationConfig::new(
1319+
1,
1320+
AggregationType::CountMinSketch,
1321+
String::new(),
1322+
HashMap::new(),
1323+
promql_utilities::data_model::key_by_label_names::KeyByLabelNames::new(vec![]),
1324+
promql_utilities::data_model::key_by_label_names::KeyByLabelNames::new(vec![]),
1325+
promql_utilities::data_model::key_by_label_names::KeyByLabelNames::new(vec![]),
1326+
String::new(),
1327+
60,
1328+
0,
1329+
WindowType::Tumbling,
1330+
"m".to_string(),
1331+
"m".to_string(),
1332+
None,
1333+
None,
1334+
None,
1335+
None,
1336+
);
1337+
let err = create_accumulator_updater(&config)
1338+
.err()
1339+
.expect("expected Err for missing row_num param");
1340+
assert!(
1341+
err.contains("row_num"),
1342+
"error should mention the missing parameter name"
1343+
);
1344+
}
1345+
1346+
fn kll_params_required() -> std::collections::HashMap<String, serde_json::Value> {
1347+
let mut p = std::collections::HashMap::new();
1348+
p.insert("K".to_string(), serde_json::json!(200_u64));
1349+
p
1350+
}
1351+
1352+
fn cms_params_required() -> std::collections::HashMap<String, serde_json::Value> {
1353+
let mut p = std::collections::HashMap::new();
1354+
p.insert("row_num".to_string(), serde_json::json!(4_u64));
1355+
p.insert("col_num".to_string(), serde_json::json!(1000_u64));
1356+
p
1357+
}
1358+
12441359
fn cms_heap_params_required() -> std::collections::HashMap<String, serde_json::Value> {
12451360
let mut p = std::collections::HashMap::new();
12461361
p.insert("depth".to_string(), serde_json::json!(3_u64));

0 commit comments

Comments
 (0)