ratelimiter.c

// SPDX-License-Identifier: GPL-2.0
/*
 * Copyright (C) 2015-2019 Jason A. Donenfeld <[email protected]>. All Rights Reserved.
 */

#include "ratelimiter.h"
#include <linux/siphash.h>
#include <linux/mm.h>
#include <linux/slab.h>
#include <net/ip.h>

static struct kmem_cache *entry_cache;
static hsiphash_key_t key;
static spinlock_t table_lock = __SPIN_LOCK_UNLOCKED("ratelimiter_table_lock");
static DEFINE_MUTEX(init_lock);
static u64 init_refcnt; /* Protected by init_lock, hence not atomic. */
static atomic_t total_entries = ATOMIC_INIT(0);
static unsigned int max_entries, table_size;
static void wg_ratelimiter_gc_entries(struct work_struct *);
static DECLARE_DEFERRABLE_WORK(gc_work, wg_ratelimiter_gc_entries);
static struct hlist_head *table_v4;
#if IS_ENABLED(CONFIG_IPV6)
static struct hlist_head *table_v6;
#endif

struct ratelimiter_entry {
	u64 last_time_ns, tokens, ip;
	void *net;
	spinlock_t lock;
	struct hlist_node hash;
	struct rcu_head rcu;
};
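
/* Each packet costs PACKET_COST nanoseconds' worth of "tokens" (50 ms at 20
 * packets per second), and TOKEN_MAX caps a bucket at PACKETS_BURSTABLE
 * packets' worth, so at most that many packets can be sent back to back.
 */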
enum {
	PACKETS_PER_SECOND = 20,
	PACKETS_BURSTABLE = 5,
	PACKET_COST = NSEC_PER_SEC / PACKETS_PER_SECOND,
	TOKEN_MAX = PACKET_COST * PACKETS_BURSTABLE
};

static void entry_free(struct rcu_head *rcu)
{
	kmem_cache_free(entry_cache,
			container_of(rcu, struct ratelimiter_entry, rcu));
	atomic_dec(&total_entries);
}

static void entry_uninit(struct ratelimiter_entry *entry)
{
	hlist_del_rcu(&entry->hash);
	call_rcu(&entry->rcu, entry_free);
}

/* Calling this function with a NULL work uninits all entries. */
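/* Otherwise, when run from the workqueue, it drops entries that have been
 * idle for more than one second and requeues itself to run again in about a
 * second.
 */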
static void wg_ratelimiter_gc_entries(struct work_struct *work)
{
	const u64 now = ktime_get_coarse_boottime_ns();
	struct ratelimiter_entry *entry;
	struct hlist_node *temp;
	unsigned int i;

	for (i = 0; i < table_size; ++i) {
		spin_lock(&table_lock);
		hlist_for_each_entry_safe(entry, temp, &table_v4[i], hash) {
			if (unlikely(!work) ||
			    now - entry->last_time_ns > NSEC_PER_SEC)
				entry_uninit(entry);
		}
#if IS_ENABLED(CONFIG_IPV6)
		hlist_for_each_entry_safe(entry, temp, &table_v6[i], hash) {
			if (unlikely(!work) ||
			    now - entry->last_time_ns > NSEC_PER_SEC)
				entry_uninit(entry);
		}
#endif
		spin_unlock(&table_lock);
		if (likely(work))
			cond_resched();
	}
	if (likely(work))
		queue_delayed_work(system_power_efficient_wq, &gc_work, HZ);
}

bool wg_ratelimiter_allow(struct sk_buff *skb, struct net *net)
{
	/* We only take the bottom half of the net pointer, so that we can hash
	 * 3 words in the end. This way, siphash's len param fits into the final
	 * u32, and we don't incur an extra round.
	 */
	const u32 net_word = (unsigned long)net;
	struct ratelimiter_entry *entry;
	struct hlist_head *bucket;
	u64 ip;
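
	/* The bucket is chosen by hashing (net, source address) with hsiphash
	 * under the random key set in wg_ratelimiter_init(), masked into the
	 * power-of-two table.
	 */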
	if (skb->protocol == htons(ETH_P_IP)) {
		ip = (u64 __force)ip_hdr(skb)->saddr;
		bucket = &table_v4[hsiphash_2u32(net_word, ip, &key) &
				   (table_size - 1)];
	}
#if IS_ENABLED(CONFIG_IPV6)
	else if (skb->protocol == htons(ETH_P_IPV6)) {
		/* Only use 64 bits, so as to ratelimit the whole /64. */
		memcpy(&ip, &ipv6_hdr(skb)->saddr, sizeof(ip));
		bucket = &table_v6[hsiphash_3u32(net_word, ip >> 32, ip, &key) &
				   (table_size - 1)];
	}
#endif
	else
		return false;
	rcu_read_lock();
	hlist_for_each_entry_rcu(entry, bucket, hash) {
		if (entry->net == net && entry->ip == ip) {
			u64 now, tokens;
			bool ret;

			/* Quasi-inspired by nft_limit.c, but this is actually a
			 * slightly different algorithm. Namely, we incorporate
			 * the burst as part of the maximum tokens, rather than
			 * as part of the rate.
			 */
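			/* Tokens are nanoseconds of elapsed time: they accrue
			 * at one per nanosecond since the last packet, capped
			 * at TOKEN_MAX, and each admitted packet spends
			 * PACKET_COST of them.
			 */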
			spin_lock(&entry->lock);
			now = ktime_get_coarse_boottime_ns();
			tokens = min_t(u64, TOKEN_MAX,
				       entry->tokens + now -
					       entry->last_time_ns);
			entry->last_time_ns = now;
			ret = tokens >= PACKET_COST;
			entry->tokens = ret ? tokens - PACKET_COST : tokens;
			spin_unlock(&entry->lock);
			rcu_read_unlock();
			return ret;
		}
	}
	rcu_read_unlock();

	if (atomic_inc_return(&total_entries) > max_entries)
		goto err_oom;

	entry = kmem_cache_alloc(entry_cache, GFP_KERNEL);
	if (unlikely(!entry))
		goto err_oom;

	entry->net = net;
	entry->ip = ip;
	INIT_HLIST_NODE(&entry->hash);
	spin_lock_init(&entry->lock);
	entry->last_time_ns = ktime_get_coarse_boottime_ns();
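	/* A fresh entry starts with a full bucket, minus the cost of the
	 * packet being admitted right now.
	 */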
	entry->tokens = TOKEN_MAX - PACKET_COST;
	spin_lock(&table_lock);
	hlist_add_head_rcu(&entry->hash, bucket);
	spin_unlock(&table_lock);
	return true;

err_oom:
	atomic_dec(&total_entries);
	return false;
}

int wg_ratelimiter_init(void)
{
	mutex_lock(&init_lock);
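	/* The tables and kmem cache are shared; only the first caller actually
	 * allocates them, and later callers just take a reference.
	 */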
	if (++init_refcnt != 1)
		goto out;

	entry_cache = KMEM_CACHE(ratelimiter_entry, 0);
	if (!entry_cache)
		goto err;

	/* xt_hashlimit.c uses a slightly different algorithm for ratelimiting,
	 * but what it shares in common is that it uses a massive hashtable. So,
	 * we borrow their wisdom about good table sizes on different systems
	 * dependent on RAM. This calculation here comes from there.
	 */
	table_size = (totalram_pages() > (1U << 30) / PAGE_SIZE) ? 8192 :
		max_t(unsigned long, 16, roundup_pow_of_two(
			(totalram_pages() << PAGE_SHIFT) /
			(1U << 14) / sizeof(struct hlist_head)));
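	/* Refuse new entries once there are, on average, 8 per bucket. */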
	max_entries = table_size * 8;

	table_v4 = kvcalloc(table_size, sizeof(*table_v4), GFP_KERNEL);
	if (unlikely(!table_v4))
		goto err_kmemcache;

#if IS_ENABLED(CONFIG_IPV6)
	table_v6 = kvcalloc(table_size, sizeof(*table_v6), GFP_KERNEL);
	if (unlikely(!table_v6)) {
		kvfree(table_v4);
		goto err_kmemcache;
	}
#endif

	queue_delayed_work(system_power_efficient_wq, &gc_work, HZ);
	get_random_bytes(&key, sizeof(key));
out:
	mutex_unlock(&init_lock);
	return 0;

err_kmemcache:
	kmem_cache_destroy(entry_cache);
err:
	--init_refcnt;
	mutex_unlock(&init_lock);
	return -ENOMEM;
}

void wg_ratelimiter_uninit(void)
{
	mutex_lock(&init_lock);
	if (!init_refcnt || --init_refcnt)
		goto out;

	cancel_delayed_work_sync(&gc_work);
	wg_ratelimiter_gc_entries(NULL);
	rcu_barrier();
	kvfree(table_v4);
#if IS_ENABLED(CONFIG_IPV6)
	kvfree(table_v6);
#endif
	kmem_cache_destroy(entry_cache);
out:
	mutex_unlock(&init_lock);
}

#include "selftest/ratelimiter.c"