sm4-ce-glue.c

/* SPDX-License-Identifier: GPL-2.0-or-later */
/*
 * SM4 Cipher Algorithm, using ARMv8 Crypto Extensions
 * as specified in
 * https://tools.ietf.org/id/draft-ribose-cfrg-sm4-10.html
 *
 * Copyright (C) 2022, Alibaba Group.
 * Copyright (C) 2022 Tianjia Zhang <[email protected]>
 */

#include <linux/module.h>
#include <linux/crypto.h>
#include <linux/kernel.h>
#include <linux/cpufeature.h>
#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/sm4.h>

#define BYTES2BLKS(nbytes)	((nbytes) >> 4)

asmlinkage void sm4_ce_expand_key(const u8 *key, u32 *rkey_enc, u32 *rkey_dec,
				  const u32 *fk, const u32 *ck);
asmlinkage void sm4_ce_crypt_block(const u32 *rkey, u8 *dst, const u8 *src);
asmlinkage void sm4_ce_crypt(const u32 *rkey, u8 *dst, const u8 *src,
			     unsigned int nblks);
asmlinkage void sm4_ce_cbc_enc(const u32 *rkey, u8 *dst, const u8 *src,
			       u8 *iv, unsigned int nblks);
asmlinkage void sm4_ce_cbc_dec(const u32 *rkey, u8 *dst, const u8 *src,
			       u8 *iv, unsigned int nblks);
asmlinkage void sm4_ce_cfb_enc(const u32 *rkey, u8 *dst, const u8 *src,
			       u8 *iv, unsigned int nblks);
asmlinkage void sm4_ce_cfb_dec(const u32 *rkey, u8 *dst, const u8 *src,
			       u8 *iv, unsigned int nblks);
asmlinkage void sm4_ce_ctr_enc(const u32 *rkey, u8 *dst, const u8 *src,
			       u8 *iv, unsigned int nblks);
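
/*
 * The routines declared above are implemented in the accompanying
 * SM4-CE assembly (sm4-ce-core.S in the upstream tree).  Each bulk
 * helper processes 'nblks' complete 16-byte blocks with the expanded
 * round keys and, for the chained modes, updates the IV buffer in
 * place so the walk loops below can call it again on the next chunk.
 */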

static int sm4_setkey(struct crypto_skcipher *tfm, const u8 *key,
		      unsigned int key_len)
{
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	if (key_len != SM4_KEY_SIZE)
		return -EINVAL;

	sm4_ce_expand_key(key, ctx->rkey_enc, ctx->rkey_dec,
			  crypto_sm4_fk, crypto_sm4_ck);
	return 0;
}
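
/*
 * Walk-based ECB helper shared by the encrypt and decrypt entry points:
 * skcipher_walk_virt() maps the request's scatterlists chunk by chunk,
 * kernel_neon_begin()/kernel_neon_end() bracket the SIMD work, and any
 * bytes that do not form a full block are handed back to
 * skcipher_walk_done().
 */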

static int sm4_ecb_do_crypt(struct skcipher_request *req, const u32 *rkey)
{
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		unsigned int nblks;

		kernel_neon_begin();

		nblks = BYTES2BLKS(nbytes);
		if (nblks) {
			sm4_ce_crypt(rkey, dst, src, nblks);
			nbytes -= nblks * SM4_BLOCK_SIZE;
		}

		kernel_neon_end();

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}

static int sm4_ecb_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return sm4_ecb_do_crypt(req, ctx->rkey_enc);
}

static int sm4_ecb_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return sm4_ecb_do_crypt(req, ctx->rkey_dec);
}
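
/*
 * CBC mode: the assembly helpers chain through walk.iv, which on entry
 * holds the IV for the first block of the chunk and on return holds the
 * last ciphertext block, ready for the next chunk of the walk.
 */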

static int sm4_cbc_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		unsigned int nblks;

		kernel_neon_begin();

		nblks = BYTES2BLKS(nbytes);
		if (nblks) {
			sm4_ce_cbc_enc(ctx->rkey_enc, dst, src, walk.iv, nblks);
			nbytes -= nblks * SM4_BLOCK_SIZE;
		}

		kernel_neon_end();

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}

static int sm4_cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		unsigned int nblks;

		kernel_neon_begin();

		nblks = BYTES2BLKS(nbytes);
		if (nblks) {
			sm4_ce_cbc_dec(ctx->rkey_dec, dst, src, walk.iv, nblks);
			nbytes -= nblks * SM4_BLOCK_SIZE;
		}

		kernel_neon_end();

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}
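
/*
 * CFB mode: full blocks go through the bulk assembly helpers; on the
 * final chunk of the walk (walk.nbytes == walk.total) any trailing
 * partial block is covered by encrypting the current IV with
 * sm4_ce_crypt_block() and XOR-ing that keystream over the remaining
 * bytes via crypto_xor_cpy().  Note that CFB decryption also uses the
 * encryption round keys, since the keystream is always produced by the
 * block cipher's forward direction.
 */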

static int sm4_cfb_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		unsigned int nblks;

		kernel_neon_begin();

		nblks = BYTES2BLKS(nbytes);
		if (nblks) {
			sm4_ce_cfb_enc(ctx->rkey_enc, dst, src, walk.iv, nblks);
			dst += nblks * SM4_BLOCK_SIZE;
			src += nblks * SM4_BLOCK_SIZE;
			nbytes -= nblks * SM4_BLOCK_SIZE;
		}

		/* tail */
		if (walk.nbytes == walk.total && nbytes > 0) {
			u8 keystream[SM4_BLOCK_SIZE];

			sm4_ce_crypt_block(ctx->rkey_enc, keystream, walk.iv);
			crypto_xor_cpy(dst, src, keystream, nbytes);
			nbytes = 0;
		}

		kernel_neon_end();

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}

static int sm4_cfb_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		unsigned int nblks;

		kernel_neon_begin();

		nblks = BYTES2BLKS(nbytes);
		if (nblks) {
			sm4_ce_cfb_dec(ctx->rkey_enc, dst, src, walk.iv, nblks);
			dst += nblks * SM4_BLOCK_SIZE;
			src += nblks * SM4_BLOCK_SIZE;
			nbytes -= nblks * SM4_BLOCK_SIZE;
		}

		/* tail */
		if (walk.nbytes == walk.total && nbytes > 0) {
			u8 keystream[SM4_BLOCK_SIZE];

			sm4_ce_crypt_block(ctx->rkey_enc, keystream, walk.iv);
			crypto_xor_cpy(dst, src, keystream, nbytes);
			nbytes = 0;
		}

		kernel_neon_end();

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}
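
/*
 * CTR mode: keystream blocks are generated from the big-endian counter
 * in walk.iv, which the bulk helper advances once per block.  For a
 * trailing partial block the counter block is encrypted with
 * sm4_ce_crypt_block(), the counter is bumped with crypto_inc(), and the
 * keystream is XOR-ed over the remaining bytes.  Encryption and
 * decryption are identical.
 */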

static int sm4_ctr_crypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		unsigned int nblks;

		kernel_neon_begin();

		nblks = BYTES2BLKS(nbytes);
		if (nblks) {
			sm4_ce_ctr_enc(ctx->rkey_enc, dst, src, walk.iv, nblks);
			dst += nblks * SM4_BLOCK_SIZE;
			src += nblks * SM4_BLOCK_SIZE;
			nbytes -= nblks * SM4_BLOCK_SIZE;
		}

		/* tail */
		if (walk.nbytes == walk.total && nbytes > 0) {
			u8 keystream[SM4_BLOCK_SIZE];

			sm4_ce_crypt_block(ctx->rkey_enc, keystream, walk.iv);
			crypto_inc(walk.iv, SM4_BLOCK_SIZE);
			crypto_xor_cpy(dst, src, keystream, nbytes);
			nbytes = 0;
		}

		kernel_neon_end();

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}
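
/*
 * Registered skcipher algorithms.  cra_priority 400 places this driver
 * above the generic C implementation, so it is preferred whenever the
 * CPU exposes the SM4 instructions.  CFB and CTR are stream modes,
 * hence cra_blocksize is 1 and chunksize carries the underlying
 * 16-byte block size instead.
 */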

static struct skcipher_alg sm4_algs[] = {
	{
		.base = {
			.cra_name		= "ecb(sm4)",
			.cra_driver_name	= "ecb-sm4-ce",
			.cra_priority		= 400,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.setkey		= sm4_setkey,
		.encrypt	= sm4_ecb_encrypt,
		.decrypt	= sm4_ecb_decrypt,
	}, {
		.base = {
			.cra_name		= "cbc(sm4)",
			.cra_driver_name	= "cbc-sm4-ce",
			.cra_priority		= 400,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.setkey		= sm4_setkey,
		.encrypt	= sm4_cbc_encrypt,
		.decrypt	= sm4_cbc_decrypt,
	}, {
		.base = {
			.cra_name		= "cfb(sm4)",
			.cra_driver_name	= "cfb-sm4-ce",
			.cra_priority		= 400,
			.cra_blocksize		= 1,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.chunksize	= SM4_BLOCK_SIZE,
		.setkey		= sm4_setkey,
		.encrypt	= sm4_cfb_encrypt,
		.decrypt	= sm4_cfb_decrypt,
	}, {
		.base = {
			.cra_name		= "ctr(sm4)",
			.cra_driver_name	= "ctr-sm4-ce",
			.cra_priority		= 400,
			.cra_blocksize		= 1,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.chunksize	= SM4_BLOCK_SIZE,
		.setkey		= sm4_setkey,
		.encrypt	= sm4_ctr_crypt,
		.decrypt	= sm4_ctr_crypt,
	}
};
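
/*
 * Module plumbing: sm4_init()/sm4_exit() register and unregister the
 * four skciphers above.  module_cpu_feature_match(SM4, ...) only runs
 * sm4_init() on CPUs that advertise the SM4 Crypto Extension and emits
 * the corresponding module alias so the module can be loaded
 * automatically on matching hardware.
 */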

static int __init sm4_init(void)
{
	return crypto_register_skciphers(sm4_algs, ARRAY_SIZE(sm4_algs));
}

static void __exit sm4_exit(void)
{
	crypto_unregister_skciphers(sm4_algs, ARRAY_SIZE(sm4_algs));
}

module_cpu_feature_match(SM4, sm4_init);
module_exit(sm4_exit);

MODULE_DESCRIPTION("SM4 ECB/CBC/CFB/CTR using ARMv8 Crypto Extensions");
MODULE_ALIAS_CRYPTO("sm4-ce");
MODULE_ALIAS_CRYPTO("sm4");
MODULE_ALIAS_CRYPTO("ecb(sm4)");
MODULE_ALIAS_CRYPTO("cbc(sm4)");
MODULE_ALIAS_CRYPTO("cfb(sm4)");
MODULE_ALIAS_CRYPTO("ctr(sm4)");
MODULE_AUTHOR("Tianjia Zhang <[email protected]>");
MODULE_LICENSE("GPL v2");