sm4-neon-glue.c

/* SPDX-License-Identifier: GPL-2.0-or-later */
/*
 * SM4 Cipher Algorithm, using ARMv8 NEON
 * as specified in
 * https://tools.ietf.org/id/draft-ribose-cfrg-sm4-10.html
 *
 * Copyright (C) 2022, Alibaba Group.
 * Copyright (C) 2022 Tianjia Zhang <[email protected]>
 */

#include <linux/module.h>
#include <linux/crypto.h>
#include <linux/kernel.h>
#include <linux/cpufeature.h>
#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/sm4.h>

/*
 * Number of full SM4 blocks in @nbytes, and the same count rounded
 * down to a multiple of 8 (the width of the 8-block NEON routines).
 */
#define BYTES2BLKS(nbytes) ((nbytes) >> 4)
#define BYTES2BLK8(nbytes) (((nbytes) >> 4) & ~(8 - 1))

asmlinkage void sm4_neon_crypt_blk1_8(const u32 *rkey, u8 *dst, const u8 *src,
                                      unsigned int nblks);
asmlinkage void sm4_neon_crypt_blk8(const u32 *rkey, u8 *dst, const u8 *src,
                                    unsigned int nblks);
asmlinkage void sm4_neon_cbc_dec_blk8(const u32 *rkey, u8 *dst, const u8 *src,
                                      u8 *iv, unsigned int nblks);
asmlinkage void sm4_neon_cfb_dec_blk8(const u32 *rkey, u8 *dst, const u8 *src,
                                      u8 *iv, unsigned int nblks);
asmlinkage void sm4_neon_ctr_enc_blk8(const u32 *rkey, u8 *dst, const u8 *src,
                                      u8 *iv, unsigned int nblks);

static int sm4_setkey(struct crypto_skcipher *tfm, const u8 *key,
                      unsigned int key_len)
{
        struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

        return sm4_expandkey(ctx, key, key_len);
}

static int sm4_ecb_do_crypt(struct skcipher_request *req, const u32 *rkey)
{
        struct skcipher_walk walk;
        unsigned int nbytes;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while ((nbytes = walk.nbytes) > 0) {
                const u8 *src = walk.src.virt.addr;
                u8 *dst = walk.dst.virt.addr;
                unsigned int nblks;

                kernel_neon_begin();

                /* 8-block NEON path first, then any remaining 1..7 blocks */
                nblks = BYTES2BLK8(nbytes);
                if (nblks) {
                        sm4_neon_crypt_blk8(rkey, dst, src, nblks);
                        dst += nblks * SM4_BLOCK_SIZE;
                        src += nblks * SM4_BLOCK_SIZE;
                        nbytes -= nblks * SM4_BLOCK_SIZE;
                }

                nblks = BYTES2BLKS(nbytes);
                if (nblks) {
                        sm4_neon_crypt_blk1_8(rkey, dst, src, nblks);
                        nbytes -= nblks * SM4_BLOCK_SIZE;
                }

                kernel_neon_end();

                err = skcipher_walk_done(&walk, nbytes);
        }

        return err;
}

static int sm4_ecb_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

        return sm4_ecb_do_crypt(req, ctx->rkey_enc);
}

static int sm4_ecb_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

        return sm4_ecb_do_crypt(req, ctx->rkey_dec);
}

static int sm4_cbc_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        unsigned int nbytes;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while ((nbytes = walk.nbytes) > 0) {
                const u8 *iv = walk.iv;
                const u8 *src = walk.src.virt.addr;
                u8 *dst = walk.dst.virt.addr;

                /* CBC encryption is inherently serial; use the generic block function */
                while (nbytes >= SM4_BLOCK_SIZE) {
                        crypto_xor_cpy(dst, src, iv, SM4_BLOCK_SIZE);
                        sm4_crypt_block(ctx->rkey_enc, dst, dst);
                        iv = dst;
                        src += SM4_BLOCK_SIZE;
                        dst += SM4_BLOCK_SIZE;
                        nbytes -= SM4_BLOCK_SIZE;
                }
                if (iv != walk.iv)
                        memcpy(walk.iv, iv, SM4_BLOCK_SIZE);

                err = skcipher_walk_done(&walk, nbytes);
        }

        return err;
}

static int sm4_cbc_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        unsigned int nbytes;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while ((nbytes = walk.nbytes) > 0) {
                const u8 *src = walk.src.virt.addr;
                u8 *dst = walk.dst.virt.addr;
                unsigned int nblks;

                kernel_neon_begin();

                nblks = BYTES2BLK8(nbytes);
                if (nblks) {
                        sm4_neon_cbc_dec_blk8(ctx->rkey_dec, dst, src,
                                              walk.iv, nblks);
                        dst += nblks * SM4_BLOCK_SIZE;
                        src += nblks * SM4_BLOCK_SIZE;
                        nbytes -= nblks * SM4_BLOCK_SIZE;
                }

                nblks = BYTES2BLKS(nbytes);
                if (nblks) {
                        u8 keystream[SM4_BLOCK_SIZE * 8];
                        u8 iv[SM4_BLOCK_SIZE];
                        int i;

                        /* decrypt the remaining 1..7 blocks into a temporary buffer */
                        sm4_neon_crypt_blk1_8(ctx->rkey_dec, keystream,
                                              src, nblks);

                        /*
                         * Walk the blocks in reverse so that XORing with the
                         * previous ciphertext block is safe for in-place
                         * (src == dst) operation; save the last ciphertext
                         * block first, as it becomes the next IV.
                         */
                        src += ((int)nblks - 2) * SM4_BLOCK_SIZE;
                        dst += (nblks - 1) * SM4_BLOCK_SIZE;

                        memcpy(iv, src + SM4_BLOCK_SIZE, SM4_BLOCK_SIZE);

                        for (i = nblks - 1; i > 0; i--) {
                                crypto_xor_cpy(dst, src,
                                               &keystream[i * SM4_BLOCK_SIZE],
                                               SM4_BLOCK_SIZE);
                                src -= SM4_BLOCK_SIZE;
                                dst -= SM4_BLOCK_SIZE;
                        }
                        crypto_xor_cpy(dst, walk.iv,
                                       keystream, SM4_BLOCK_SIZE);

                        memcpy(walk.iv, iv, SM4_BLOCK_SIZE);
                        nbytes -= nblks * SM4_BLOCK_SIZE;
                }

                kernel_neon_end();

                err = skcipher_walk_done(&walk, nbytes);
        }

        return err;
}

static int sm4_cfb_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        unsigned int nbytes;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while ((nbytes = walk.nbytes) > 0) {
                u8 keystream[SM4_BLOCK_SIZE];
                const u8 *iv = walk.iv;
                const u8 *src = walk.src.virt.addr;
                u8 *dst = walk.dst.virt.addr;

                while (nbytes >= SM4_BLOCK_SIZE) {
                        sm4_crypt_block(ctx->rkey_enc, keystream, iv);
                        crypto_xor_cpy(dst, src, keystream, SM4_BLOCK_SIZE);
                        iv = dst;
                        src += SM4_BLOCK_SIZE;
                        dst += SM4_BLOCK_SIZE;
                        nbytes -= SM4_BLOCK_SIZE;
                }
                if (iv != walk.iv)
                        memcpy(walk.iv, iv, SM4_BLOCK_SIZE);

                /* tail */
                if (walk.nbytes == walk.total && nbytes > 0) {
                        sm4_crypt_block(ctx->rkey_enc, keystream, walk.iv);
                        crypto_xor_cpy(dst, src, keystream, nbytes);
                        nbytes = 0;
                }

                err = skcipher_walk_done(&walk, nbytes);
        }

        return err;
}

static int sm4_cfb_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        unsigned int nbytes;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while ((nbytes = walk.nbytes) > 0) {
                const u8 *src = walk.src.virt.addr;
                u8 *dst = walk.dst.virt.addr;
                unsigned int nblks;

                kernel_neon_begin();

                nblks = BYTES2BLK8(nbytes);
                if (nblks) {
                        sm4_neon_cfb_dec_blk8(ctx->rkey_enc, dst, src,
                                              walk.iv, nblks);
                        dst += nblks * SM4_BLOCK_SIZE;
                        src += nblks * SM4_BLOCK_SIZE;
                        nbytes -= nblks * SM4_BLOCK_SIZE;
                }

                nblks = BYTES2BLKS(nbytes);
                if (nblks) {
                        u8 keystream[SM4_BLOCK_SIZE * 8];

                        /*
                         * The block cipher inputs for CFB decryption are the
                         * IV followed by the first nblks - 1 ciphertext
                         * blocks; the last ciphertext block becomes the
                         * next IV.
                         */
                        memcpy(keystream, walk.iv, SM4_BLOCK_SIZE);
                        if (nblks > 1)
                                memcpy(&keystream[SM4_BLOCK_SIZE], src,
                                       (nblks - 1) * SM4_BLOCK_SIZE);
                        memcpy(walk.iv, src + (nblks - 1) * SM4_BLOCK_SIZE,
                               SM4_BLOCK_SIZE);

                        sm4_neon_crypt_blk1_8(ctx->rkey_enc, keystream,
                                              keystream, nblks);

                        crypto_xor_cpy(dst, src, keystream,
                                       nblks * SM4_BLOCK_SIZE);
                        dst += nblks * SM4_BLOCK_SIZE;
                        src += nblks * SM4_BLOCK_SIZE;
                        nbytes -= nblks * SM4_BLOCK_SIZE;
                }

                kernel_neon_end();

                /* tail */
                if (walk.nbytes == walk.total && nbytes > 0) {
                        u8 keystream[SM4_BLOCK_SIZE];

                        sm4_crypt_block(ctx->rkey_enc, keystream, walk.iv);
                        crypto_xor_cpy(dst, src, keystream, nbytes);
                        nbytes = 0;
                }

                err = skcipher_walk_done(&walk, nbytes);
        }

        return err;
}

static int sm4_ctr_crypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        unsigned int nbytes;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while ((nbytes = walk.nbytes) > 0) {
                const u8 *src = walk.src.virt.addr;
                u8 *dst = walk.dst.virt.addr;
                unsigned int nblks;

                kernel_neon_begin();

                nblks = BYTES2BLK8(nbytes);
                if (nblks) {
                        sm4_neon_ctr_enc_blk8(ctx->rkey_enc, dst, src,
                                              walk.iv, nblks);
                        dst += nblks * SM4_BLOCK_SIZE;
                        src += nblks * SM4_BLOCK_SIZE;
                        nbytes -= nblks * SM4_BLOCK_SIZE;
                }

                nblks = BYTES2BLKS(nbytes);
                if (nblks) {
                        u8 keystream[SM4_BLOCK_SIZE * 8];
                        int i;

                        /* expand walk.iv into successive counter blocks, then encrypt them in one call */
                        for (i = 0; i < nblks; i++) {
                                memcpy(&keystream[i * SM4_BLOCK_SIZE],
                                       walk.iv, SM4_BLOCK_SIZE);
                                crypto_inc(walk.iv, SM4_BLOCK_SIZE);
                        }
                        sm4_neon_crypt_blk1_8(ctx->rkey_enc, keystream,
                                              keystream, nblks);

                        crypto_xor_cpy(dst, src, keystream,
                                       nblks * SM4_BLOCK_SIZE);
                        dst += nblks * SM4_BLOCK_SIZE;
                        src += nblks * SM4_BLOCK_SIZE;
                        nbytes -= nblks * SM4_BLOCK_SIZE;
                }

                kernel_neon_end();

                /* tail */
                if (walk.nbytes == walk.total && nbytes > 0) {
                        u8 keystream[SM4_BLOCK_SIZE];

                        sm4_crypt_block(ctx->rkey_enc, keystream, walk.iv);
                        crypto_inc(walk.iv, SM4_BLOCK_SIZE);
                        crypto_xor_cpy(dst, src, keystream, nbytes);
                        nbytes = 0;
                }

                err = skcipher_walk_done(&walk, nbytes);
        }

        return err;
}

static struct skcipher_alg sm4_algs[] = {
        {
                .base = {
                        .cra_name = "ecb(sm4)",
                        .cra_driver_name = "ecb-sm4-neon",
                        .cra_priority = 200,
                        .cra_blocksize = SM4_BLOCK_SIZE,
                        .cra_ctxsize = sizeof(struct sm4_ctx),
                        .cra_module = THIS_MODULE,
                },
                .min_keysize = SM4_KEY_SIZE,
                .max_keysize = SM4_KEY_SIZE,
                .setkey = sm4_setkey,
                .encrypt = sm4_ecb_encrypt,
                .decrypt = sm4_ecb_decrypt,
        }, {
                .base = {
                        .cra_name = "cbc(sm4)",
                        .cra_driver_name = "cbc-sm4-neon",
                        .cra_priority = 200,
                        .cra_blocksize = SM4_BLOCK_SIZE,
                        .cra_ctxsize = sizeof(struct sm4_ctx),
                        .cra_module = THIS_MODULE,
                },
                .min_keysize = SM4_KEY_SIZE,
                .max_keysize = SM4_KEY_SIZE,
                .ivsize = SM4_BLOCK_SIZE,
                .setkey = sm4_setkey,
                .encrypt = sm4_cbc_encrypt,
                .decrypt = sm4_cbc_decrypt,
        }, {
                .base = {
                        .cra_name = "cfb(sm4)",
                        .cra_driver_name = "cfb-sm4-neon",
                        .cra_priority = 200,
                        .cra_blocksize = 1,
                        .cra_ctxsize = sizeof(struct sm4_ctx),
                        .cra_module = THIS_MODULE,
                },
                .min_keysize = SM4_KEY_SIZE,
                .max_keysize = SM4_KEY_SIZE,
                .ivsize = SM4_BLOCK_SIZE,
                .chunksize = SM4_BLOCK_SIZE,
                .setkey = sm4_setkey,
                .encrypt = sm4_cfb_encrypt,
                .decrypt = sm4_cfb_decrypt,
        }, {
                .base = {
                        .cra_name = "ctr(sm4)",
                        .cra_driver_name = "ctr-sm4-neon",
                        .cra_priority = 200,
                        .cra_blocksize = 1,
                        .cra_ctxsize = sizeof(struct sm4_ctx),
                        .cra_module = THIS_MODULE,
                },
                .min_keysize = SM4_KEY_SIZE,
                .max_keysize = SM4_KEY_SIZE,
                .ivsize = SM4_BLOCK_SIZE,
                .chunksize = SM4_BLOCK_SIZE,
                .setkey = sm4_setkey,
                .encrypt = sm4_ctr_crypt,
                .decrypt = sm4_ctr_crypt,
        }
};

static int __init sm4_init(void)
{
        return crypto_register_skciphers(sm4_algs, ARRAY_SIZE(sm4_algs));
}

static void __exit sm4_exit(void)
{
        crypto_unregister_skciphers(sm4_algs, ARRAY_SIZE(sm4_algs));
}

/* register only when the CPU provides Advanced SIMD (NEON) */
module_cpu_feature_match(ASIMD, sm4_init);
module_exit(sm4_exit);

MODULE_DESCRIPTION("SM4 ECB/CBC/CFB/CTR using ARMv8 NEON");
MODULE_ALIAS_CRYPTO("sm4-neon");
MODULE_ALIAS_CRYPTO("sm4");
MODULE_ALIAS_CRYPTO("ecb(sm4)");
MODULE_ALIAS_CRYPTO("cbc(sm4)");
MODULE_ALIAS_CRYPTO("cfb(sm4)");
MODULE_ALIAS_CRYPTO("ctr(sm4)");
MODULE_AUTHOR("Tianjia Zhang <[email protected]>");
MODULE_LICENSE("GPL v2");