Skip to content

Commit 95515a8

Browse files
Base64url support (#1229)
1 parent 84faa7e commit 95515a8

File tree

5 files changed

+310
-79
lines changed

5 files changed

+310
-79
lines changed

include/aws/common/encoding.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,21 +63,39 @@ int aws_hex_decode(const struct aws_byte_cursor *AWS_RESTRICT to_decode, struct
6363
AWS_COMMON_API
6464
int aws_base64_compute_encoded_len(size_t to_encode_len, size_t *encoded_len);
6565

66+
/*
67+
* Computes the length necessary to store the output of aws_base64_url_encode call.
68+
* returns -1 on failure, and 0 on success. encoded_length will be set on
69+
* success.
70+
*/
71+
AWS_COMMON_API
72+
int aws_base64_url_compute_encoded_len(size_t to_encode_len, size_t *encoded_len);
73+
6674
/*
6775
* Base 64 encodes the contents of to_encode and stores the result in output.
6876
*/
6977
AWS_COMMON_API
7078
int aws_base64_encode(const struct aws_byte_cursor *AWS_RESTRICT to_encode, struct aws_byte_buf *AWS_RESTRICT output);
7179

80+
/*
81+
* Base 64 URL encodes the contents of to_encode and stores the result in output.
82+
*/
83+
AWS_COMMON_API
84+
int aws_base64_url_encode(
85+
const struct aws_byte_cursor *AWS_RESTRICT to_encode,
86+
struct aws_byte_buf *AWS_RESTRICT output);
87+
7288
/*
7389
* Computes the length necessary to store the output of aws_base64_decode call.
90+
* Note: works on both regular and url base64.
7491
* returns -1 on failure, and 0 on success. decoded_len will be set on success.
7592
*/
7693
AWS_COMMON_API
7794
int aws_base64_compute_decoded_len(const struct aws_byte_cursor *AWS_RESTRICT to_decode, size_t *decoded_len);
7895

7996
/*
8097
* Base 64 decodes the contents of to_decode and stores the result in output.
98+
* Note: works on both regular and url base64.
8199
*/
82100
AWS_COMMON_API
83101
int aws_base64_decode(const struct aws_byte_cursor *AWS_RESTRICT to_decode, struct aws_byte_buf *AWS_RESTRICT output);

source/arch/intel/encoding_avx2.c

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,19 +52,28 @@ static inline __m256i translate_exact(__m256i in, uint8_t match, uint8_t decode)
5252
* on decode failure, returns false, else returns true on success.
5353
*/
5454
static inline bool decode_vec(__m256i *in) {
55-
__m256i tmp1, tmp2, tmp3;
55+
__m256i tmp1, tmp2, tmp3, tmp4, tmp5;
5656

5757
/*
5858
* Base64 decoding table, see RFC4648
5959
*
6060
* Note that we use multiple vector registers to try to allow the CPU to
61-
* paralellize the merging ORs
61+
* parallelize the merging ORs
6262
*/
6363
tmp1 = translate_range(*in, 'A', 'Z', 0 + 1);
6464
tmp2 = translate_range(*in, 'a', 'z', 26 + 1);
6565
tmp3 = translate_range(*in, '0', '9', 52 + 1);
66-
tmp1 = _mm256_or_si256(tmp1, translate_exact(*in, '+', 62 + 1));
67-
tmp2 = _mm256_or_si256(tmp2, translate_exact(*in, '/', 63 + 1));
66+
// Handle both '+' and '-' for value 62
67+
tmp4 = translate_exact(*in, '+', 62 + 1);
68+
tmp4 = _mm256_or_si256(tmp4, translate_exact(*in, '-', 62 + 1));
69+
70+
// Handle both '/' and '_' for value 63
71+
tmp5 = translate_exact(*in, '/', 63 + 1);
72+
tmp5 = _mm256_or_si256(tmp5, translate_exact(*in, '_', 63 + 1));
73+
74+
// Combine all results
75+
tmp1 = _mm256_or_si256(tmp1, tmp4);
76+
tmp2 = _mm256_or_si256(tmp2, tmp5);
6877
tmp3 = _mm256_or_si256(tmp3, _mm256_or_si256(tmp1, tmp2));
6978

7079
/*

source/encoding.c

Lines changed: 127 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,17 @@
55

66
#include <aws/common/encoding.h>
77

8+
#include <aws/common/logging.h>
89
#include <ctype.h>
910
#include <stdlib.h>
1011

1112
#ifdef USE_SIMD_ENCODING
1213
size_t aws_common_private_base64_decode_sse41(const unsigned char *in, unsigned char *out, size_t len);
13-
void aws_common_private_base64_encode_sse41(const unsigned char *in, unsigned char *out, size_t len);
14+
void aws_common_private_base64_encode_sse41(
15+
const unsigned char *in,
16+
unsigned char *out,
17+
size_t len,
18+
bool url_safe_encoding);
1419
bool aws_common_private_has_avx2(void);
1520
#else
1621
/*
@@ -25,10 +30,15 @@ static inline size_t aws_common_private_base64_decode_sse41(const unsigned char
2530
AWS_ASSERT(false);
2631
return SIZE_MAX; /* unreachable */
2732
}
28-
static inline void aws_common_private_base64_encode_sse41(const unsigned char *in, unsigned char *out, size_t len) {
33+
static inline void aws_common_private_base64_encode_sse41(
34+
const unsigned char *in,
35+
unsigned char *out,
36+
size_t len,
37+
bool url_safe_encoding) {
2938
(void)in;
3039
(void)out;
3140
(void)len;
41+
(void)url_safe_encoding;
3242
AWS_ASSERT(false);
3343
}
3444
static inline bool aws_common_private_has_avx2(void) {
@@ -40,17 +50,18 @@ static const uint8_t *HEX_CHARS = (const uint8_t *)"0123456789abcdef";
4050

4151
static const uint8_t BASE64_SENTINEL_VALUE = 0xff;
4252
static const uint8_t BASE64_ENCODING_TABLE[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
53+
static const uint8_t BASE64_URL_ENCODING_TABLE[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
4354

4455
/* in this table, 0xDD is an invalid decoded value, if you have to do byte counting for any reason, there's 16 bytes
4556
* per row. Reformatting is turned off to make sure this stays as 16 bytes per line. */
4657
/* clang-format off */
4758
static const uint8_t BASE64_DECODING_TABLE[256] = {
4859
64, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD,
4960
0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD,
50-
0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 62, 0xDD, 0xDD, 0xDD, 63,
61+
0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 62, 0xDD, 62, 0xDD, 63,
5162
52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 0xDD, 0xDD, 0xDD, 255, 0xDD, 0xDD,
5263
0xDD, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
53-
15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD,
64+
15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 0xDD, 0xDD, 0xDD, 0xDD, 63,
5465
0xDD, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40,
5566
41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD,
5667
0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD,
@@ -216,73 +227,102 @@ int aws_hex_decode(const struct aws_byte_cursor *AWS_RESTRICT to_decode, struct
216227
}
217228

218229
int aws_base64_compute_encoded_len(size_t to_encode_len, size_t *encoded_len) {
219-
AWS_ASSERT(encoded_len);
230+
AWS_ERROR_PRECONDITION(encoded_len);
220231

221232
/* For every 3 bytes (rounded up) of unencoded input, there will be 4 ascii characters of encoded output.
222233
* Rounding is because the output will be padded with '=' chars if necessary to make it divisible by 4. */
223234

224235
/* adding 2 before dividing by 3 is a trick to round up during division */
225-
size_t tmp = to_encode_len + 2;
236+
size_t tmp = 0;
226237

227-
if (AWS_UNLIKELY(tmp < to_encode_len)) {
228-
return aws_raise_error(AWS_ERROR_OVERFLOW_DETECTED);
238+
if (AWS_UNLIKELY(aws_add_size_checked(to_encode_len, 2, &tmp))) {
239+
return AWS_OP_ERR;
229240
}
230241

231242
tmp /= 3;
232-
size_t overflow_check = tmp;
233-
tmp = 4 * tmp;
243+
if (AWS_UNLIKELY(aws_mul_size_checked(tmp, 4, &tmp))) {
244+
return AWS_OP_ERR;
245+
}
234246

235-
if (AWS_UNLIKELY(tmp < overflow_check)) {
236-
return aws_raise_error(AWS_ERROR_OVERFLOW_DETECTED);
247+
*encoded_len = tmp;
248+
249+
return AWS_OP_SUCCESS;
250+
}
251+
252+
int aws_base64_url_compute_encoded_len(size_t to_encode_len, size_t *encoded_len) {
253+
AWS_ERROR_PRECONDITION(encoded_len);
254+
255+
/* Just do direct math for each each 6 bits map to char. +5 is same trick to as before to round up. */
256+
257+
size_t tmp = 0;
258+
259+
if (AWS_UNLIKELY(aws_mul_size_checked(to_encode_len, 8, &tmp))) {
260+
return AWS_OP_ERR;
261+
}
262+
263+
if (AWS_UNLIKELY(aws_add_size_checked(tmp, 5, &tmp))) {
264+
return AWS_OP_ERR;
237265
}
238266

267+
tmp /= 6;
268+
239269
*encoded_len = tmp;
240270

241271
return AWS_OP_SUCCESS;
242272
}
243273

274+
bool s_ispadding(uint8_t ch) {
275+
return ch == '=';
276+
}
277+
244278
int aws_base64_compute_decoded_len(const struct aws_byte_cursor *AWS_RESTRICT to_decode, size_t *decoded_len) {
245279
AWS_ASSERT(to_decode);
246280
AWS_ASSERT(decoded_len);
247281

248-
const size_t len = to_decode->len;
249-
const uint8_t *input = to_decode->ptr;
282+
/* strip padding */
283+
struct aws_byte_cursor trimmed = aws_byte_cursor_right_trim_pred(to_decode, s_ispadding);
284+
285+
const size_t len = trimmed.len;
250286

251287
if (len == 0) {
252288
*decoded_len = 0;
253289
return AWS_OP_SUCCESS;
254290
}
255291

256-
/* ensure it's divisible by 4 */
257-
if (AWS_UNLIKELY(len & 0x03)) {
292+
/* impossible len */
293+
if (AWS_UNLIKELY(len % 4 == 1)) {
258294
return aws_raise_error(AWS_ERROR_INVALID_BASE64_STR);
259295
}
260296

261-
/* For every 4 ascii characters of encoded input, there will be 3 bytes of decoded output (deal with padding later)
262-
* decoded_len = 3/4 * len <-- note that result will be smaller then len, so overflow can be avoided
263-
* = (len / 4) * 3 <-- divide before multiply to avoid overflow
264-
*/
265-
size_t decoded_len_tmp = (len / 4) * 3;
297+
/* For every 4 ascii characters of encoded input, there will be 3 bytes of decoded output. */
298+
size_t decoded_len_tmp = 0;
266299

267-
/* But last two ascii chars might be padding. */
268-
AWS_ASSERT(len >= 4); /* we checked earlier len != 0, and was divisible by 4 */
269-
size_t padding = 0;
270-
if (input[len - 1] == '=' && input[len - 2] == '=') { /*last two chars are = */
271-
padding = 2;
272-
} else if (input[len - 1] == '=') { /*last char is = */
273-
padding = 1;
300+
if (aws_mul_size_checked(len, 3, &decoded_len_tmp)) {
301+
return AWS_OP_ERR;
274302
}
275303

276-
*decoded_len = decoded_len_tmp - padding;
304+
decoded_len_tmp >>= 2;
305+
306+
/* But last two ascii chars might be padding. */
307+
308+
*decoded_len = decoded_len_tmp;
277309
return AWS_OP_SUCCESS;
278310
}
279311

280-
int aws_base64_encode(const struct aws_byte_cursor *AWS_RESTRICT to_encode, struct aws_byte_buf *AWS_RESTRICT output) {
281-
AWS_ASSERT(to_encode->len == 0 || to_encode->ptr != NULL);
312+
static int s_base64_encode(
313+
const struct aws_byte_cursor *AWS_RESTRICT to_encode,
314+
struct aws_byte_buf *AWS_RESTRICT output,
315+
bool do_url_safe_encoding) {
282316

283317
size_t encoded_length = 0;
284-
if (AWS_UNLIKELY(aws_base64_compute_encoded_len(to_encode->len, &encoded_length))) {
285-
return AWS_OP_ERR;
318+
if (do_url_safe_encoding) {
319+
if (AWS_UNLIKELY(aws_base64_url_compute_encoded_len(to_encode->len, &encoded_length))) {
320+
return AWS_OP_ERR;
321+
}
322+
} else {
323+
if (AWS_UNLIKELY(aws_base64_compute_encoded_len(to_encode->len, &encoded_length))) {
324+
return AWS_OP_ERR;
325+
}
286326
}
287327

288328
size_t needed_capacity = 0;
@@ -296,8 +336,14 @@ int aws_base64_encode(const struct aws_byte_cursor *AWS_RESTRICT to_encode, stru
296336

297337
AWS_ASSERT(needed_capacity == 0 || output->buffer != NULL);
298338

299-
if (aws_common_private_has_avx2()) {
300-
aws_common_private_base64_encode_sse41(to_encode->ptr, output->buffer + output->len, to_encode->len);
339+
/*
340+
* Note: avx2 impl currently does not support url base64 (no padding -> output not divisible by 4 -> it writes out
341+
* of bounds). Just use software version for now (since need for base64 url is small) instead of hacking together
342+
* half hearted avx2 impl.
343+
*/
344+
if (!do_url_safe_encoding && aws_common_private_has_avx2()) {
345+
aws_common_private_base64_encode_sse41(
346+
to_encode->ptr, output->buffer + output->len, to_encode->len, do_url_safe_encoding);
301347
output->len += encoded_length;
302348
return AWS_OP_SUCCESS;
303349
}
@@ -307,7 +353,9 @@ int aws_base64_encode(const struct aws_byte_cursor *AWS_RESTRICT to_encode, stru
307353
size_t remainder_count = (buffer_length % 3);
308354
size_t str_index = output->len;
309355

310-
for (size_t i = 0; i < to_encode->len; i += 3) {
356+
const uint8_t *encoding_table = do_url_safe_encoding ? BASE64_URL_ENCODING_TABLE : BASE64_ENCODING_TABLE;
357+
358+
for (size_t i = 0; i < buffer_length; i += 3) {
311359
uint32_t block = to_encode->ptr[i];
312360

313361
block <<= 8;
@@ -316,17 +364,21 @@ int aws_base64_encode(const struct aws_byte_cursor *AWS_RESTRICT to_encode, stru
316364
}
317365

318366
block <<= 8;
319-
if (AWS_LIKELY(i + 2 < to_encode->len)) {
367+
if (AWS_LIKELY(i + 2 < buffer_length)) {
320368
block = block | to_encode->ptr[i + 2];
321369
}
322370

323-
output->buffer[str_index++] = BASE64_ENCODING_TABLE[(block >> 18) & 0x3F];
324-
output->buffer[str_index++] = BASE64_ENCODING_TABLE[(block >> 12) & 0x3F];
325-
output->buffer[str_index++] = BASE64_ENCODING_TABLE[(block >> 6) & 0x3F];
326-
output->buffer[str_index++] = BASE64_ENCODING_TABLE[block & 0x3F];
371+
output->buffer[str_index++] = encoding_table[(block >> 18) & 0x3F];
372+
output->buffer[str_index++] = encoding_table[(block >> 12) & 0x3F];
373+
if (AWS_LIKELY(i + 1 < buffer_length)) {
374+
output->buffer[str_index++] = encoding_table[(block >> 6) & 0x3F];
375+
if (AWS_LIKELY(i + 2 < buffer_length)) {
376+
output->buffer[str_index++] = encoding_table[block & 0x3F];
377+
}
378+
}
327379
}
328380

329-
if (remainder_count > 0) {
381+
if (!do_url_safe_encoding && remainder_count > 0) {
330382
output->buffer[output->len + block_count * 4 - 1] = '=';
331383
if (remainder_count == 1) {
332384
output->buffer[output->len + block_count * 4 - 2] = '=';
@@ -338,6 +390,16 @@ int aws_base64_encode(const struct aws_byte_cursor *AWS_RESTRICT to_encode, stru
338390
return AWS_OP_SUCCESS;
339391
}
340392

393+
int aws_base64_encode(const struct aws_byte_cursor *AWS_RESTRICT to_encode, struct aws_byte_buf *AWS_RESTRICT output) {
394+
return s_base64_encode(to_encode, output, false);
395+
}
396+
397+
int aws_base64_url_encode(
398+
const struct aws_byte_cursor *AWS_RESTRICT to_encode,
399+
struct aws_byte_buf *AWS_RESTRICT output) {
400+
return s_base64_encode(to_encode, output, true);
401+
}
402+
341403
static inline int s_base64_get_decoded_value(unsigned char to_decode, uint8_t *value, int8_t allow_sentinel) {
342404

343405
uint8_t decode_value = BASE64_DECODING_TABLE[(size_t)to_decode];
@@ -360,7 +422,14 @@ int aws_base64_decode(const struct aws_byte_cursor *AWS_RESTRICT to_decode, stru
360422
return aws_raise_error(AWS_ERROR_SHORT_BUFFER);
361423
}
362424

363-
if (aws_common_private_has_avx2()) {
425+
/*
426+
* Note: avx2 version relies on input being padded to 4 byte boundary.
427+
* Regular base64 is always padded to 4, but for url variant padding is skipped.
428+
* Fall back to software version for inputs that are not divisible by 4 cleanly (aka base64 url),
429+
* as those are a corner case and we dont need them as often.
430+
* Reconsider, writing intrinsic version is usage becomes more widespread.
431+
*/
432+
if ((to_decode->len % 4 == 0) && aws_common_private_has_avx2()) {
364433
size_t result = aws_common_private_base64_decode_sse41(to_decode->ptr, output->buffer, to_decode->len);
365434
if (result == SIZE_MAX) {
366435
return aws_raise_error(AWS_ERROR_INVALID_BASE64_STR);
@@ -370,7 +439,8 @@ int aws_base64_decode(const struct aws_byte_cursor *AWS_RESTRICT to_decode, stru
370439
return AWS_OP_SUCCESS;
371440
}
372441

373-
int64_t block_count = (int64_t)to_decode->len / 4;
442+
int64_t block_count = (int64_t)(to_decode->len + 3) / 4;
443+
size_t remainder = to_decode->len % 4;
374444
size_t string_index = 0;
375445
uint8_t value1 = 0, value2 = 0, value3 = 0, value4 = 0;
376446
int64_t buffer_index = 0;
@@ -394,9 +464,19 @@ int aws_base64_decode(const struct aws_byte_cursor *AWS_RESTRICT to_decode, stru
394464

395465
if (buffer_index >= 0) {
396466
if (s_base64_get_decoded_value(to_decode->ptr[string_index++], &value1, 0) ||
397-
s_base64_get_decoded_value(to_decode->ptr[string_index++], &value2, 0) ||
398-
s_base64_get_decoded_value(to_decode->ptr[string_index++], &value3, 1) ||
399-
s_base64_get_decoded_value(to_decode->ptr[string_index], &value4, 1)) {
467+
s_base64_get_decoded_value(to_decode->ptr[string_index++], &value2, 0)) {
468+
return aws_raise_error(AWS_ERROR_INVALID_BASE64_STR);
469+
}
470+
471+
value3 = BASE64_SENTINEL_VALUE;
472+
value4 = BASE64_SENTINEL_VALUE;
473+
474+
if ((remainder == 3 || remainder == 0) &&
475+
s_base64_get_decoded_value(to_decode->ptr[string_index++], &value3, 1)) {
476+
return aws_raise_error(AWS_ERROR_INVALID_BASE64_STR);
477+
}
478+
479+
if (remainder == 0 && s_base64_get_decoded_value(to_decode->ptr[string_index++], &value4, 1)) {
400480
return aws_raise_error(AWS_ERROR_INVALID_BASE64_STR);
401481
}
402482

0 commit comments

Comments
 (0)