sm3_base.h 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. /* SPDX-License-Identifier: GPL-2.0-only */
  2. /*
  3. * sm3_base.h - core logic for SM3 implementations
  4. *
  5. * Copyright (C) 2017 ARM Limited or its affiliates.
  6. * Written by Gilad Ben-Yossef <[email protected]>
  7. */
  8. #ifndef _CRYPTO_SM3_BASE_H
  9. #define _CRYPTO_SM3_BASE_H
  10. #include <crypto/internal/hash.h>
  11. #include <crypto/sm3.h>
  12. #include <linux/crypto.h>
  13. #include <linux/module.h>
  14. #include <linux/string.h>
  15. #include <asm/unaligned.h>
  16. typedef void (sm3_block_fn)(struct sm3_state *sst, u8 const *src, int blocks);
  17. static inline int sm3_base_init(struct shash_desc *desc)
  18. {
  19. struct sm3_state *sctx = shash_desc_ctx(desc);
  20. sctx->state[0] = SM3_IVA;
  21. sctx->state[1] = SM3_IVB;
  22. sctx->state[2] = SM3_IVC;
  23. sctx->state[3] = SM3_IVD;
  24. sctx->state[4] = SM3_IVE;
  25. sctx->state[5] = SM3_IVF;
  26. sctx->state[6] = SM3_IVG;
  27. sctx->state[7] = SM3_IVH;
  28. sctx->count = 0;
  29. return 0;
  30. }
  31. static inline int sm3_base_do_update(struct shash_desc *desc,
  32. const u8 *data,
  33. unsigned int len,
  34. sm3_block_fn *block_fn)
  35. {
  36. struct sm3_state *sctx = shash_desc_ctx(desc);
  37. unsigned int partial = sctx->count % SM3_BLOCK_SIZE;
  38. sctx->count += len;
  39. if (unlikely((partial + len) >= SM3_BLOCK_SIZE)) {
  40. int blocks;
  41. if (partial) {
  42. int p = SM3_BLOCK_SIZE - partial;
  43. memcpy(sctx->buffer + partial, data, p);
  44. data += p;
  45. len -= p;
  46. block_fn(sctx, sctx->buffer, 1);
  47. }
  48. blocks = len / SM3_BLOCK_SIZE;
  49. len %= SM3_BLOCK_SIZE;
  50. if (blocks) {
  51. block_fn(sctx, data, blocks);
  52. data += blocks * SM3_BLOCK_SIZE;
  53. }
  54. partial = 0;
  55. }
  56. if (len)
  57. memcpy(sctx->buffer + partial, data, len);
  58. return 0;
  59. }
  60. static inline int sm3_base_do_finalize(struct shash_desc *desc,
  61. sm3_block_fn *block_fn)
  62. {
  63. const int bit_offset = SM3_BLOCK_SIZE - sizeof(__be64);
  64. struct sm3_state *sctx = shash_desc_ctx(desc);
  65. __be64 *bits = (__be64 *)(sctx->buffer + bit_offset);
  66. unsigned int partial = sctx->count % SM3_BLOCK_SIZE;
  67. sctx->buffer[partial++] = 0x80;
  68. if (partial > bit_offset) {
  69. memset(sctx->buffer + partial, 0x0, SM3_BLOCK_SIZE - partial);
  70. partial = 0;
  71. block_fn(sctx, sctx->buffer, 1);
  72. }
  73. memset(sctx->buffer + partial, 0x0, bit_offset - partial);
  74. *bits = cpu_to_be64(sctx->count << 3);
  75. block_fn(sctx, sctx->buffer, 1);
  76. return 0;
  77. }
  78. static inline int sm3_base_finish(struct shash_desc *desc, u8 *out)
  79. {
  80. struct sm3_state *sctx = shash_desc_ctx(desc);
  81. __be32 *digest = (__be32 *)out;
  82. int i;
  83. for (i = 0; i < SM3_DIGEST_SIZE / sizeof(__be32); i++)
  84. put_unaligned_be32(sctx->state[i], digest++);
  85. memzero_explicit(sctx, sizeof(*sctx));
  86. return 0;
  87. }
  88. #endif /* _CRYPTO_SM3_BASE_H */