AWS SDK for C++

AWS SDK for C++ Version 1.11.606

DefaultRateLimiter.h
#pragma once

#include <aws/core/Core_EXPORTS.h>
#include <aws/core/utils/ratelimiter/RateLimiterInterface.h>

#include <algorithm>
#include <chrono>
#include <cstdint>
#include <functional>
#include <limits>
#include <mutex>
#include <thread>
#include <utility>

16namespace Aws {
17namespace Utils {
18namespace RateLimits {
22template <typename CLOCK = std::chrono::steady_clock, typename DUR = std::chrono::seconds, bool RENORMALIZE_RATE_CHANGES = true>
24 public:
26
27 using InternalTimePointType = std::chrono::time_point<CLOCK>;
29
33 DefaultRateLimiter(int64_t maxRate, ElapsedTimeFunctionType elapsedTimeFunction = CLOCK::now)
34 : m_elapsedTimeFunction(elapsedTimeFunction),
35 m_maxRate(0),
36 m_accumulatorLock(),
37 m_accumulator(0),
38 m_accumulatorFraction(0),
39 m_accumulatorUpdated(),
40 m_replenishNumerator(0),
41 m_replenishDenominator(0),
42 m_delayNumerator(0),
43 m_delayDenominator(0) {
44 // verify we're not going to divide by zero due to goofy type parameterization
45 static_assert(DUR::period::num > 0, "Rate duration must have positive numerator");
46 static_assert(DUR::period::den > 0, "Rate duration must have positive denominator");
47 static_assert(CLOCK::duration::period::num > 0, "RateLimiter clock duration must have positive numerator");
48 static_assert(CLOCK::duration::period::den > 0, "RateLimiter clock duration must have positive denominator");
49
50 DefaultRateLimiter::SetRate(maxRate, true);
51 }
52
53 virtual ~DefaultRateLimiter() = default;
54
58 virtual DelayType ApplyCost(int64_t cost) override {
59 std::lock_guard<std::recursive_mutex> lock(m_accumulatorLock);
60
61 auto now = m_elapsedTimeFunction();
62 auto elapsedTime = (now - m_accumulatorUpdated).count();
63
64 // check for overflow case
65 if (m_replenishNumerator != 0 && elapsedTime > 0 &&
66 (elapsedTime > std::numeric_limits<int64_t>::max() / m_replenishNumerator ||
67 m_accumulatorFraction > std::numeric_limits<int64_t>::max() - (elapsedTime * m_replenishNumerator))) {
68 m_accumulator = m_maxRate;
69 m_accumulatorFraction = 0;
70 } else {
71 // replenish the accumulator based on how much time has passed
72 auto temp = elapsedTime * m_replenishNumerator + m_accumulatorFraction;
73 m_accumulator += temp / m_replenishDenominator;
74 m_accumulatorFraction = temp % m_replenishDenominator;
75
76 // the accumulator is capped based on the maximum rate
77 m_accumulator = (std::min)(m_accumulator, m_maxRate);
78 if (m_accumulator == m_maxRate) {
79 m_accumulatorFraction = 0;
80 }
81 }
82
83 // if the accumulator is still negative, then we'll have to wait
84 DelayType delay(0);
85 if (m_accumulator < 0) {
86 delay = DelayType(-m_accumulator * m_delayDenominator / m_delayNumerator);
87 }
88
89 // apply the cost to the accumulator after the delay has been calculated; the next call will end up paying for our cost
90 m_accumulator -= cost;
91 m_accumulatorUpdated = now;
92
93 return delay;
94 }
95
99 virtual void ApplyAndPayForCost(int64_t cost) override {
100 auto costInMilliseconds = ApplyCost(cost);
101 if (costInMilliseconds.count() > 0) {
102 std::this_thread::sleep_for(costInMilliseconds);
103 }
104 }
105
109 virtual void SetRate(int64_t rate, bool resetAccumulator = false) override {
110 std::lock_guard<std::recursive_mutex> lock(m_accumulatorLock);
111
112 // rate must always be positive
113 rate = (std::max)(static_cast<int64_t>(1), rate);
114
115 if (resetAccumulator) {
116 m_accumulator = rate;
117 m_accumulatorFraction = 0;
118 m_accumulatorUpdated = m_elapsedTimeFunction();
119 } else {
120 // sync the accumulator to current time
121 ApplyCost(0); // this call is why we need a recursive mutex
122
123 if (ShouldRenormalizeAccumulatorOnRateChange()) {
124 // now renormalize the accumulator and its fractional part against the new rate
125 // the idea here is we want to preserve the desired wait based on the previous rate
126 //
127 // As an example:
128 // Say we had a rate of 100/s and our accumulator was -500 (ie the next ApplyCost would incur a 5 second delay)
129 // If we change the rate to 1000/s and want to preserve that delay, we need to scale the accumulator to -5000
130 m_accumulator = m_accumulator * rate / m_maxRate;
131 m_accumulatorFraction = m_accumulatorFraction * rate / m_maxRate;
132 }
133 }
134
135 m_maxRate = rate;
136
137 // Helper constants that represent the amount replenished per CLOCK time period; use the gcd to reduce them in order to try and minimize
138 // the chance of integer overflow
139 m_replenishNumerator = m_maxRate * DUR::period::den * CLOCK::duration::period::num;
140 m_replenishDenominator = DUR::period::num * CLOCK::duration::period::den;
141 auto gcd = ComputeGCD(m_replenishNumerator, m_replenishDenominator);
142 m_replenishNumerator /= gcd;
143 m_replenishDenominator /= gcd;
144
145 // Helper constants that represent the delay per unit of costAccumulator; use the gcd to reduce them in order to try and minimize the
146 // chance of integer overflow
147 m_delayNumerator = m_maxRate * DelayType::period::num * DUR::period::den;
148 m_delayDenominator = DelayType::period::den * DUR::period::num;
149 gcd = ComputeGCD(m_delayNumerator, m_delayDenominator);
150 m_delayNumerator /= gcd;
151 m_delayDenominator /= gcd;
152 }
153
154 private:
155 int64_t ComputeGCD(int64_t num1, int64_t num2) const {
156 // Euclid's
157 while (num2 != 0) {
158 int64_t rem = num1 % num2;
159 num1 = num2;
160 num2 = rem;
161 }
162
163 return num1;
164 }
165
166 bool ShouldRenormalizeAccumulatorOnRateChange() const { return RENORMALIZE_RATE_CHANGES; }
167
169 ElapsedTimeFunctionType m_elapsedTimeFunction;
170
172 int64_t m_maxRate;
173
175 std::recursive_mutex m_accumulatorLock;
176
179 int64_t m_accumulator;
180
183 int64_t m_accumulatorFraction;
184
186 InternalTimePointType m_accumulatorUpdated;
187
189 int64_t m_replenishNumerator;
190 int64_t m_replenishDenominator;
191 int64_t m_delayNumerator;
192 int64_t m_delayDenominator;
193};
194
195} // namespace RateLimits
196} // namespace Utils
197} // namespace Aws
virtual void SetRate(int64_t rate, bool resetAccumulator=false) override
virtual DelayType ApplyCost(int64_t cost) override
virtual void ApplyAndPayForCost(int64_t cost) override
DefaultRateLimiter(int64_t maxRate, ElapsedTimeFunctionType elapsedTimeFunction=CLOCK::now)
std::function< InternalTimePointType()> ElapsedTimeFunctionType
std::chrono::time_point< CLOCK > InternalTimePointType