Skip to content

Commit 2d61af5

Browse files
committed
fix intervalmap
1 parent 9229bbc commit 2d61af5

1 file changed

Lines changed: 35 additions & 15 deletions

File tree

src/util.h

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -287,46 +287,49 @@ class IntervalMap
287287
void set_interval(K begin, K end, const V v) {
288288
if (begin >= end) return;
289289

290-
// get interval that `begin` intersects with (inclusive)
291-
iter begin_intersect = --my_map.upper_bound(begin);
292-
293-
// get interval that `end` intersects with (inclusive)
290+
// get end intersector (inclusive)
294291
iter end_intersect = --my_map.upper_bound(end);
295292

296-
// if required, insert at start
297-
iter inserted_start = my_map.end();
298-
if (begin_intersect->second != v) {
299-
inserted_start = my_map.insert_or_assign(begin_intersect, begin, v);
300-
}
301-
302293
// if required, insert at end
303294
iter inserted_end = my_map.end();
304295
if (end_intersect->second != v) {
305296
inserted_end = my_map.insert_or_assign(end_intersect, end, end_intersect->second);
306297
}
307298

299+
// get begin intersector (inclusive)
300+
iter begin_intersect = --my_map.upper_bound(begin);
301+
302+
// if required, insert at start
303+
iter inserted_start = my_map.end();
304+
if (begin_intersect->second != v) {
305+
inserted_start = my_map.insert_or_assign(begin_intersect, begin, v);
306+
}
307+
308308
// delete everyone inside
309309
iter del_start = inserted_start != my_map.end() ? inserted_start : begin_intersect;
310-
if (del_start->first == begin) {
310+
if (del_start->first < begin || (del_start->first == begin && std::prev(del_start)->second != v)) {
311311
del_start++;
312312
}
313313

314314
iter del_end = inserted_end != my_map.end() ? inserted_end : end_intersect;
315+
if (del_end != my_map.end() && del_end->first == end && std::next(del_end) != my_map.end() && del_end->second == v) {
316+
del_end++;
317+
}
315318

316-
if (del_start != my_map.end()) {
319+
if (del_start != my_map.end() && del_start->first < del_end->first) {
317320
my_map.erase(del_start, del_end);
318321
}
319322
}
320323

321324
// iterator which traverses elements in sorted order (smallest to largest)
322325
// O(1)
323-
constexpr inline iter &begin() {
326+
constexpr inline auto &begin() {
324327
return my_map.begin();
325328
}
326329

327330
// end of elements
328331
// O(1)
329-
constexpr inline iter &end() {
332+
constexpr inline auto &end() {
330333
return my_map.end();
331334
}
332335

@@ -338,7 +341,12 @@ class IntervalMap
338341

339342
// get value at key `k`
340343
// O(log N)
341-
const V operator[](K const& k) {
344+
const inline V& operator[](K const& k) const {
345+
return (--my_map.upper_bound(k))->second;
346+
}
347+
348+
// don't return reference because we don't want to allow map[whatever] = value as it would edit the next-earliest value instead of inserting a new element. (TODO: write better.)
349+
inline V operator[](K const& k) {
342350
return (--my_map.upper_bound(k))->second;
343351
}
344352

@@ -347,6 +355,18 @@ class IntervalMap
347355
my_map.clear();
348356
my_map.insert(my_map.end(), { std::numeric_limits<K>::lowest(), v });
349357
}
358+
359+
// get num intervals overall
360+
// always at least 1
361+
constexpr int num_intervals() {
362+
return my_map.size();
363+
}
364+
365+
// get number of intervals in a range
366+
// UNTESTED
367+
constexpr int num_intervals(const K& start, const K& end) {
368+
return get_interval(end) - get_interval(start);
369+
}
350370
};
351371

352372
#endif // __UTIL_H__

0 commit comments

Comments
 (0)