diff --git a/src/dsc.h b/src/dsc.h index ab2e68f..a244be0 100644 --- a/src/dsc.h +++ b/src/dsc.h @@ -315,6 +315,17 @@ public: } }; +// Comparison functor used by IdList and related classes +template +struct CompareId { + bool operator()(T const& lhs, T const& rhs) const { + return lhs.h.v < rhs.h.v; + } + bool operator()(T const& lhs, H rhs) const { + return lhs.h.v < rhs.v; + } +}; + // A list, where each element has an integer identifier. The list is kept // sorted by that identifier, and items can be looked up in log n time by // id. @@ -325,6 +336,8 @@ public: int n; int elemsAllocated; + using Compare = CompareId; + bool IsEmpty() const { return n == 0; } @@ -344,6 +357,31 @@ public: return t->h; } + T * LowerBound(T const& t) { + if(IsEmpty()) { + return nullptr; + } + auto it = std::lower_bound(begin(), end(), t, Compare()); + return it; + } + + T * LowerBound(H const& h) { + if(IsEmpty()) { + return nullptr; + } + auto it = std::lower_bound(begin(), end(), h, Compare()); + return it; + } + + int LowerBoundIndex(T const& t) { + if(IsEmpty()) { + return 0; + } + auto it = LowerBound(t); + auto idx = std::distance(begin(), it); + auto i = static_cast(idx); + return i; + } void ReserveMore(int howMuch) { if(n + howMuch > elemsAllocated) { elemsAllocated = n + howMuch; @@ -361,21 +399,12 @@ public: if(n >= elemsAllocated) { ReserveMore((elemsAllocated + 32)*2 - n); } - - int first = 0, last = n; - // We know that we must insert within the closed interval [first,last] - while(first != last) { - int mid = (first + last)/2; - H hm = elem[mid].h; + auto newIndex = LowerBoundIndex(*t); + if (newIndex < n) { + H hm = elem[newIndex].h; ssassert(hm.v != t->h.v, "Handle isn't unique"); - if(hm.v > t->h.v) { - last = mid; - } else if(hm.v < t->h.v) { - first = mid + 1; - } } - - int i = first; + int i = static_cast(newIndex); new(&elem[n]) T(); std::move_backward(elem + i, elem + n, elem + n + 1); elem[i] = *t; @@ -392,17 +421,10 @@ public: if(IsEmpty()) { return -1; } - int first = 0, last = n-1; - while(first <= last) { - int mid = (first + last)/2; - H hm = elem[mid].h; - if(hm.v > h.v) { - last = mid-1; // and first stays the same - } else if(hm.v < h.v) { - first = mid+1; // and last stays the same - } else { - return mid; - } + auto it = LowerBound(h); + auto idx = std::distance(begin(), it); + if (idx < n) { + return idx; } return -1; } @@ -411,19 +433,14 @@ public: if(IsEmpty()) { return nullptr; } - int first = 0, last = n-1; - while(first <= last) { - int mid = (first + last)/2; - H hm = elem[mid].h; - if(hm.v > h.v) { - last = mid-1; // and first stays the same - } else if(hm.v < h.v) { - first = mid+1; // and last stays the same - } else { - return &(elem[mid]); - } + auto it = LowerBound(h); + if (it == nullptr || it == end()) { + return nullptr; } - return NULL; + if (it->h.v == h.v) { + return it; + } + return nullptr; } T *First() {