55
66#include <aws/common/encoding.h>
77
8+ #include <aws/common/logging.h>
89#include <ctype.h>
910#include <stdlib.h>
1011
1112#ifdef USE_SIMD_ENCODING
1213size_t aws_common_private_base64_decode_sse41 (const unsigned char * in , unsigned char * out , size_t len );
13- void aws_common_private_base64_encode_sse41 (const unsigned char * in , unsigned char * out , size_t len );
14+ void aws_common_private_base64_encode_sse41 (
15+ const unsigned char * in ,
16+ unsigned char * out ,
17+ size_t len ,
18+ bool url_safe_encoding );
1419bool aws_common_private_has_avx2 (void );
1520#else
1621/*
@@ -25,10 +30,15 @@ static inline size_t aws_common_private_base64_decode_sse41(const unsigned char
2530 AWS_ASSERT (false);
2631 return SIZE_MAX ; /* unreachable */
2732}
28- static inline void aws_common_private_base64_encode_sse41 (const unsigned char * in , unsigned char * out , size_t len ) {
/*
 * Placeholder for the SSE4.1 base64 encoder, used on the #else side of
 * USE_SIMD_ENCODING. It performs no encoding: every argument is ignored and
 * the body only asserts, so it is not meant to be reached.
 */
static inline void aws_common_private_base64_encode_sse41(
    const unsigned char *in,
    unsigned char *out,
    size_t len,
    bool url_safe_encoding) {
    /* Mark each parameter as deliberately unused to keep the build warning-free. */
    (void)url_safe_encoding;
    (void)len;
    (void)out;
    (void)in;
    AWS_ASSERT(false);
}
3444static inline bool aws_common_private_has_avx2 (void ) {
@@ -40,17 +50,18 @@ static const uint8_t *HEX_CHARS = (const uint8_t *)"0123456789abcdef";
4050
/* Marker value 0xff (255); the decoding table below stores 255 in the slot for '='. */
static const uint8_t BASE64_SENTINEL_VALUE = 0xff;

/* Standard base64 alphabet: entry i is the character emitted for 6-bit value i (62 -> '+', 63 -> '/'). */
static const uint8_t BASE64_ENCODING_TABLE[] =
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    "abcdefghijklmnopqrstuvwxyz"
    "0123456789"
    "+/";

/* URL-safe alphabet: identical to the standard one except 62 -> '-' and 63 -> '_'. */
static const uint8_t BASE64_URL_ENCODING_TABLE[] =
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    "abcdefghijklmnopqrstuvwxyz"
    "0123456789"
    "-_";
4354
4455/* in this table, 0xDD is an invalid decoded value, if you have to do byte counting for any reason, there's 16 bytes
4556 * per row. Reformatting is turned off to make sure this stays as 16 bytes per line. */
4657/* clang-format off */
4758static const uint8_t BASE64_DECODING_TABLE [256 ] = {
4859 64 , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD ,
4960 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD ,
50- 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 62 , 0xDD , 0xDD , 0xDD , 63 ,
61+ 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 62 , 0xDD , 62 , 0xDD , 63 ,
5162 52 , 53 , 54 , 55 , 56 , 57 , 58 , 59 , 60 , 61 , 0xDD , 0xDD , 0xDD , 255 , 0xDD , 0xDD ,
5263 0xDD , 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 , 13 , 14 ,
53- 15 , 16 , 17 , 18 , 19 , 20 , 21 , 22 , 23 , 24 , 25 , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD ,
64+ 15 , 16 , 17 , 18 , 19 , 20 , 21 , 22 , 23 , 24 , 25 , 0xDD , 0xDD , 0xDD , 0xDD , 63 ,
5465 0xDD , 26 , 27 , 28 , 29 , 30 , 31 , 32 , 33 , 34 , 35 , 36 , 37 , 38 , 39 , 40 ,
5566 41 , 42 , 43 , 44 , 45 , 46 , 47 , 48 , 49 , 50 , 51 , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD ,
5667 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD , 0xDD ,
@@ -216,73 +227,102 @@ int aws_hex_decode(const struct aws_byte_cursor *AWS_RESTRICT to_decode, struct
216227}
217228
218229int aws_base64_compute_encoded_len (size_t to_encode_len , size_t * encoded_len ) {
219- AWS_ASSERT (encoded_len );
230+ AWS_ERROR_PRECONDITION (encoded_len );
220231
221232 /* For every 3 bytes (rounded up) of unencoded input, there will be 4 ascii characters of encoded output.
222233 * Rounding is because the output will be padded with '=' chars if necessary to make it divisible by 4. */
223234
224235 /* adding 2 before dividing by 3 is a trick to round up during division */
225- size_t tmp = to_encode_len + 2 ;
236+ size_t tmp = 0 ;
226237
227- if (AWS_UNLIKELY (tmp < to_encode_len )) {
228- return aws_raise_error ( AWS_ERROR_OVERFLOW_DETECTED ) ;
238+ if (AWS_UNLIKELY (aws_add_size_checked ( to_encode_len , 2 , & tmp ) )) {
239+ return AWS_OP_ERR ;
229240 }
230241
231242 tmp /= 3 ;
232- size_t overflow_check = tmp ;
233- tmp = 4 * tmp ;
243+ if (AWS_UNLIKELY (aws_mul_size_checked (tmp , 4 , & tmp ))) {
244+ return AWS_OP_ERR ;
245+ }
234246
235- if (AWS_UNLIKELY (tmp < overflow_check )) {
236- return aws_raise_error (AWS_ERROR_OVERFLOW_DETECTED );
247+ * encoded_len = tmp ;
248+
249+ return AWS_OP_SUCCESS ;
250+ }
251+
252+ int aws_base64_url_compute_encoded_len (size_t to_encode_len , size_t * encoded_len ) {
253+ AWS_ERROR_PRECONDITION (encoded_len );
254+
255+ /* Just do direct math for each each 6 bits map to char. +5 is same trick to as before to round up. */
256+
257+ size_t tmp = 0 ;
258+
259+ if (AWS_UNLIKELY (aws_mul_size_checked (to_encode_len , 8 , & tmp ))) {
260+ return AWS_OP_ERR ;
261+ }
262+
263+ if (AWS_UNLIKELY (aws_add_size_checked (tmp , 5 , & tmp ))) {
264+ return AWS_OP_ERR ;
237265 }
238266
267+ tmp /= 6 ;
268+
239269 * encoded_len = tmp ;
240270
241271 return AWS_OP_SUCCESS ;
242272}
243273
/*
 * Byte predicate: returns true only for the base64 padding character '='.
 * File-local helper, hence `static`.
 */
static bool s_ispadding(uint8_t ch) {
    return ch == '=';
}
277+
244278int aws_base64_compute_decoded_len (const struct aws_byte_cursor * AWS_RESTRICT to_decode , size_t * decoded_len ) {
245279 AWS_ASSERT (to_decode );
246280 AWS_ASSERT (decoded_len );
247281
248- const size_t len = to_decode -> len ;
249- const uint8_t * input = to_decode -> ptr ;
282+ /* strip padding */
283+ struct aws_byte_cursor trimmed = aws_byte_cursor_right_trim_pred (to_decode , s_ispadding );
284+
285+ const size_t len = trimmed .len ;
250286
251287 if (len == 0 ) {
252288 * decoded_len = 0 ;
253289 return AWS_OP_SUCCESS ;
254290 }
255291
256- /* ensure it's divisible by 4 */
257- if (AWS_UNLIKELY (len & 0x03 )) {
292+ /* impossible len */
293+ if (AWS_UNLIKELY (len % 4 == 1 )) {
258294 return aws_raise_error (AWS_ERROR_INVALID_BASE64_STR );
259295 }
260296
261- /* For every 4 ascii characters of encoded input, there will be 3 bytes of decoded output (deal with padding later)
262- * decoded_len = 3/4 * len <-- note that result will be smaller then len, so overflow can be avoided
263- * = (len / 4) * 3 <-- divide before multiply to avoid overflow
264- */
265- size_t decoded_len_tmp = (len / 4 ) * 3 ;
297+ /* For every 4 ascii characters of encoded input, there will be 3 bytes of decoded output. */
298+ size_t decoded_len_tmp = 0 ;
266299
267- /* But last two ascii chars might be padding. */
268- AWS_ASSERT (len >= 4 ); /* we checked earlier len != 0, and was divisible by 4 */
269- size_t padding = 0 ;
270- if (input [len - 1 ] == '=' && input [len - 2 ] == '=' ) { /*last two chars are = */
271- padding = 2 ;
272- } else if (input [len - 1 ] == '=' ) { /*last char is = */
273- padding = 1 ;
300+ if (aws_mul_size_checked (len , 3 , & decoded_len_tmp )) {
301+ return AWS_OP_ERR ;
274302 }
275303
276- * decoded_len = decoded_len_tmp - padding ;
304+ decoded_len_tmp >>= 2 ;
305+
306+ /* But last two ascii chars might be padding. */
307+
308+ * decoded_len = decoded_len_tmp ;
277309 return AWS_OP_SUCCESS ;
278310}
279311
280- int aws_base64_encode (const struct aws_byte_cursor * AWS_RESTRICT to_encode , struct aws_byte_buf * AWS_RESTRICT output ) {
281- AWS_ASSERT (to_encode -> len == 0 || to_encode -> ptr != NULL );
312+ static int s_base64_encode (
313+ const struct aws_byte_cursor * AWS_RESTRICT to_encode ,
314+ struct aws_byte_buf * AWS_RESTRICT output ,
315+ bool do_url_safe_encoding ) {
282316
283317 size_t encoded_length = 0 ;
284- if (AWS_UNLIKELY (aws_base64_compute_encoded_len (to_encode -> len , & encoded_length ))) {
285- return AWS_OP_ERR ;
318+ if (do_url_safe_encoding ) {
319+ if (AWS_UNLIKELY (aws_base64_url_compute_encoded_len (to_encode -> len , & encoded_length ))) {
320+ return AWS_OP_ERR ;
321+ }
322+ } else {
323+ if (AWS_UNLIKELY (aws_base64_compute_encoded_len (to_encode -> len , & encoded_length ))) {
324+ return AWS_OP_ERR ;
325+ }
286326 }
287327
288328 size_t needed_capacity = 0 ;
@@ -296,8 +336,14 @@ int aws_base64_encode(const struct aws_byte_cursor *AWS_RESTRICT to_encode, stru
296336
297337 AWS_ASSERT (needed_capacity == 0 || output -> buffer != NULL );
298338
299- if (aws_common_private_has_avx2 ()) {
300- aws_common_private_base64_encode_sse41 (to_encode -> ptr , output -> buffer + output -> len , to_encode -> len );
339+ /*
340+ * Note: avx2 impl currently does not support url base64 (no padding -> output not divisible by 4 -> it writes out
341+ * of bounds). Just use software version for now (since need for base64 url is small) instead of hacking together
342+ * half hearted avx2 impl.
343+ */
344+ if (!do_url_safe_encoding && aws_common_private_has_avx2 ()) {
345+ aws_common_private_base64_encode_sse41 (
346+ to_encode -> ptr , output -> buffer + output -> len , to_encode -> len , do_url_safe_encoding );
301347 output -> len += encoded_length ;
302348 return AWS_OP_SUCCESS ;
303349 }
@@ -307,7 +353,9 @@ int aws_base64_encode(const struct aws_byte_cursor *AWS_RESTRICT to_encode, stru
307353 size_t remainder_count = (buffer_length % 3 );
308354 size_t str_index = output -> len ;
309355
310- for (size_t i = 0 ; i < to_encode -> len ; i += 3 ) {
356+ const uint8_t * encoding_table = do_url_safe_encoding ? BASE64_URL_ENCODING_TABLE : BASE64_ENCODING_TABLE ;
357+
358+ for (size_t i = 0 ; i < buffer_length ; i += 3 ) {
311359 uint32_t block = to_encode -> ptr [i ];
312360
313361 block <<= 8 ;
@@ -316,17 +364,21 @@ int aws_base64_encode(const struct aws_byte_cursor *AWS_RESTRICT to_encode, stru
316364 }
317365
318366 block <<= 8 ;
319- if (AWS_LIKELY (i + 2 < to_encode -> len )) {
367+ if (AWS_LIKELY (i + 2 < buffer_length )) {
320368 block = block | to_encode -> ptr [i + 2 ];
321369 }
322370
323- output -> buffer [str_index ++ ] = BASE64_ENCODING_TABLE [(block >> 18 ) & 0x3F ];
324- output -> buffer [str_index ++ ] = BASE64_ENCODING_TABLE [(block >> 12 ) & 0x3F ];
325- output -> buffer [str_index ++ ] = BASE64_ENCODING_TABLE [(block >> 6 ) & 0x3F ];
326- output -> buffer [str_index ++ ] = BASE64_ENCODING_TABLE [block & 0x3F ];
371+ output -> buffer [str_index ++ ] = encoding_table [(block >> 18 ) & 0x3F ];
372+ output -> buffer [str_index ++ ] = encoding_table [(block >> 12 ) & 0x3F ];
373+ if (AWS_LIKELY (i + 1 < buffer_length )) {
374+ output -> buffer [str_index ++ ] = encoding_table [(block >> 6 ) & 0x3F ];
375+ if (AWS_LIKELY (i + 2 < buffer_length )) {
376+ output -> buffer [str_index ++ ] = encoding_table [block & 0x3F ];
377+ }
378+ }
327379 }
328380
329- if (remainder_count > 0 ) {
381+ if (! do_url_safe_encoding && remainder_count > 0 ) {
330382 output -> buffer [output -> len + block_count * 4 - 1 ] = '=' ;
331383 if (remainder_count == 1 ) {
332384 output -> buffer [output -> len + block_count * 4 - 2 ] = '=' ;
@@ -338,6 +390,16 @@ int aws_base64_encode(const struct aws_byte_cursor *AWS_RESTRICT to_encode, stru
338390 return AWS_OP_SUCCESS ;
339391}
340392
393+ int aws_base64_encode (const struct aws_byte_cursor * AWS_RESTRICT to_encode , struct aws_byte_buf * AWS_RESTRICT output ) {
394+ return s_base64_encode (to_encode , output , false);
395+ }
396+
397+ int aws_base64_url_encode (
398+ const struct aws_byte_cursor * AWS_RESTRICT to_encode ,
399+ struct aws_byte_buf * AWS_RESTRICT output ) {
400+ return s_base64_encode (to_encode , output , true);
401+ }
402+
341403static inline int s_base64_get_decoded_value (unsigned char to_decode , uint8_t * value , int8_t allow_sentinel ) {
342404
343405 uint8_t decode_value = BASE64_DECODING_TABLE [(size_t )to_decode ];
@@ -360,7 +422,14 @@ int aws_base64_decode(const struct aws_byte_cursor *AWS_RESTRICT to_decode, stru
360422 return aws_raise_error (AWS_ERROR_SHORT_BUFFER );
361423 }
362424
363- if (aws_common_private_has_avx2 ()) {
425+ /*
426+ * Note: avx2 version relies on input being padded to 4 byte boundary.
427+ * Regular base64 is always padded to 4, but for url variant padding is skipped.
428+ * Fall back to software version for inputs that are not divisible by 4 cleanly (aka base64 url),
429+ * as those are a corner case and we dont need them as often.
431+ * Reconsider writing an intrinsic version if usage becomes more widespread.
431+ */
432+ if ((to_decode -> len % 4 == 0 ) && aws_common_private_has_avx2 ()) {
364433 size_t result = aws_common_private_base64_decode_sse41 (to_decode -> ptr , output -> buffer , to_decode -> len );
365434 if (result == SIZE_MAX ) {
366435 return aws_raise_error (AWS_ERROR_INVALID_BASE64_STR );
@@ -370,7 +439,8 @@ int aws_base64_decode(const struct aws_byte_cursor *AWS_RESTRICT to_decode, stru
370439 return AWS_OP_SUCCESS ;
371440 }
372441
373- int64_t block_count = (int64_t )to_decode -> len / 4 ;
442+ int64_t block_count = (int64_t )(to_decode -> len + 3 ) / 4 ;
443+ size_t remainder = to_decode -> len % 4 ;
374444 size_t string_index = 0 ;
375445 uint8_t value1 = 0 , value2 = 0 , value3 = 0 , value4 = 0 ;
376446 int64_t buffer_index = 0 ;
@@ -394,9 +464,19 @@ int aws_base64_decode(const struct aws_byte_cursor *AWS_RESTRICT to_decode, stru
394464
395465 if (buffer_index >= 0 ) {
396466 if (s_base64_get_decoded_value (to_decode -> ptr [string_index ++ ], & value1 , 0 ) ||
397- s_base64_get_decoded_value (to_decode -> ptr [string_index ++ ], & value2 , 0 ) ||
398- s_base64_get_decoded_value (to_decode -> ptr [string_index ++ ], & value3 , 1 ) ||
399- s_base64_get_decoded_value (to_decode -> ptr [string_index ], & value4 , 1 )) {
467+ s_base64_get_decoded_value (to_decode -> ptr [string_index ++ ], & value2 , 0 )) {
468+ return aws_raise_error (AWS_ERROR_INVALID_BASE64_STR );
469+ }
470+
471+ value3 = BASE64_SENTINEL_VALUE ;
472+ value4 = BASE64_SENTINEL_VALUE ;
473+
474+ if ((remainder == 3 || remainder == 0 ) &&
475+ s_base64_get_decoded_value (to_decode -> ptr [string_index ++ ], & value3 , 1 )) {
476+ return aws_raise_error (AWS_ERROR_INVALID_BASE64_STR );
477+ }
478+
479+ if (remainder == 0 && s_base64_get_decoded_value (to_decode -> ptr [string_index ++ ], & value4 , 1 )) {
400480 return aws_raise_error (AWS_ERROR_INVALID_BASE64_STR );
401481 }
402482
0 commit comments