// SPDX-License-Identifier: GPL-2.0-only
// Copyright (C) 2019-2020 Arm Ltd.

#include <linux/compiler.h>
#include <linux/kasan-checks.h>
#include <linux/kernel.h>

#include <net/checksum.h>

/* Looks dumb, but generates nice-ish code */
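/*
 * accumulate() adds @data into the running sum with end-around carry: the
 * __uint128_t intermediate lets the carry out of bit 63 land in bit 64, and
 * adding (tmp >> 64) back in folds that carry into the low doubleword,
 * i.e. a 64-bit ones-complement addition.
 */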
static u64 accumulate(u64 sum, u64 data)
{
	__uint128_t tmp = (__uint128_t)sum + data;

	return tmp + (tmp >> 64);
}

/*
 * We over-read the buffer and this makes KASAN unhappy. Instead, disable
 * instrumentation and call kasan explicitly.
 */
unsigned int __no_sanitize_address do_csum(const unsigned char *buff, int len)
{
	unsigned int offset, shift, sum;
	const u64 *ptr;
	u64 data, sum64 = 0;

	if (unlikely(len <= 0))
		return 0;

	offset = (unsigned long)buff & 7;
	/*
	 * This is to all intents and purposes safe, since rounding down cannot
	 * result in a different page or cache line being accessed, and @buff
	 * should absolutely not be pointing to anything read-sensitive. We do,
	 * however, have to be careful not to piss off KASAN, which means using
	 * unchecked reads to accommodate the head and tail, for which we'll
	 * compensate with an explicit check up-front.
	 */
	kasan_check_read(buff, len);
	ptr = (u64 *)(buff - offset);
	len = len + offset - 8;
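	/*
	 * len now counts the bytes remaining beyond the aligned doubleword
	 * loaded below; it goes negative (down to -7) when the whole buffer
	 * fits within that first load, in which case both loops and the
	 * len > 0 block are skipped and only the tail masking runs.
	 */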

	/*
	 * Head: zero out any excess leading bytes. Shifting back by the same
	 * amount should be at least as fast as any other way of handling the
	 * odd/even alignment, and means we can ignore it until the very end.
	 */
	shift = offset * 8;
	data = *ptr++;
#ifdef __LITTLE_ENDIAN
	data = (data >> shift) << shift;
#else
	data = (data << shift) >> shift;
#endif
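
	/*
	 * For example, with a little-endian kernel and offset == 3, shift is
	 * 24 and the three bytes sitting below @buff in the aligned doubleword
	 * are cleared, so they contribute nothing to the sum.
	 */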

	/*
	 * Body: straightforward aligned loads from here on (the paired loads
	 * underlying the quadword type still only need dword alignment). The
	 * main loop strictly excludes the tail, so the second loop will always
	 * run at least once.
	 */
	while (unlikely(len > 64)) {
		__uint128_t tmp1, tmp2, tmp3, tmp4;

		tmp1 = *(__uint128_t *)ptr;
		tmp2 = *(__uint128_t *)(ptr + 2);
		tmp3 = *(__uint128_t *)(ptr + 4);
		tmp4 = *(__uint128_t *)(ptr + 6);

		len -= 64;
		ptr += 8;

		/* This is the "don't dump the carry flag into a GPR" idiom */
		tmp1 += (tmp1 >> 64) | (tmp1 << 64);
		tmp2 += (tmp2 >> 64) | (tmp2 << 64);
		tmp3 += (tmp3 >> 64) | (tmp3 << 64);
		tmp4 += (tmp4 >> 64) | (tmp4 << 64);
		tmp1 = ((tmp1 >> 64) << 64) | (tmp2 >> 64);
		tmp1 += (tmp1 >> 64) | (tmp1 << 64);
		tmp3 = ((tmp3 >> 64) << 64) | (tmp4 >> 64);
		tmp3 += (tmp3 >> 64) | (tmp3 << 64);
		tmp1 = ((tmp1 >> 64) << 64) | (tmp3 >> 64);
		tmp1 += (tmp1 >> 64) | (tmp1 << 64);
		tmp1 = ((tmp1 >> 64) << 64) | sum64;
		tmp1 += (tmp1 >> 64) | (tmp1 << 64);
		sum64 = tmp1 >> 64;
	}
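	/*
	 * This loop, and the len > 0 block below, deliberately keep the most
	 * recently loaded doubleword pending in data rather than summing it
	 * straight away, so that the tail masking at the end always operates
	 * on the last word actually read.
	 */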
	while (len > 8) {
		__uint128_t tmp;

		sum64 = accumulate(sum64, data);
		tmp = *(__uint128_t *)ptr;

		len -= 16;
		ptr += 2;

#ifdef __LITTLE_ENDIAN
		data = tmp >> 64;
		sum64 = accumulate(sum64, tmp);
#else
		data = tmp;
		sum64 = accumulate(sum64, tmp >> 64);
#endif
	}
	if (len > 0) {
		sum64 = accumulate(sum64, data);
		data = *ptr;
		len -= 8;
	}
	/*
	 * Tail: zero any over-read bytes similarly to the head, again
	 * preserving odd/even alignment.
	 */
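	/*
	 * len is now in the range [-7, 0]; -len is the number of bytes of
	 * data that lie beyond the end of the buffer and must be cleared.
	 */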
	shift = len * -8;
#ifdef __LITTLE_ENDIAN
	data = (data << shift) >> shift;
#else
	data = (data >> shift) << shift;
#endif
	sum64 = accumulate(sum64, data);

	/* Finally, folding */
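	/*
	 * The halves-swap trick folds 64 bits down to 32 and then to 16, with
	 * the result landing in the upper half each time. If @buff started at
	 * an odd address, every byte was summed in the opposite lane of its
	 * 16-bit word, so the final value is byte-swapped back.
	 */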
	sum64 += (sum64 >> 32) | (sum64 << 32);
	sum = sum64 >> 32;
	sum += (sum >> 16) | (sum << 16);
	if (offset & 1)
		return (u16)swab32(sum);

	return sum >> 16;
}
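
/*
 * Fold the IPv6 pseudo-header into the running checksum: both 128-bit
 * addresses are reduced with the same halves-swap trick as above, combined
 * with the length and protocol fields via accumulate(), and the 64-bit
 * result is folded down to 16 bits by csum_fold().
 */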
__sum16 csum_ipv6_magic(const struct in6_addr *saddr,
			const struct in6_addr *daddr,
			__u32 len, __u8 proto, __wsum csum)
{
	__uint128_t src, dst;
	u64 sum = (__force u64)csum;

	src = *(const __uint128_t *)saddr->s6_addr;
	dst = *(const __uint128_t *)daddr->s6_addr;

	sum += (__force u32)htonl(len);
#ifdef __LITTLE_ENDIAN
	sum += (u32)proto << 24;
#else
	sum += proto;
#endif
	src += (src >> 64) | (src << 64);
	dst += (dst >> 64) | (dst << 64);

	sum = accumulate(sum, src >> 64);
	sum = accumulate(sum, dst >> 64);

	sum += ((sum >> 32) | (sum << 32));

	return csum_fold((__force __wsum)(sum >> 32));
}
EXPORT_SYMBOL(csum_ipv6_magic);