11# aeskw.py - implementation of AES key wrapping
22# coding: utf-8
33#
4- # Copyright (C) 2014 Arthur de Jong
4+ # Copyright (C) 2014-2015 Arthur de Jong
55#
66# This library is free software; you can redistribute it and/or
77# modify it under the terms of the GNU Lesser General Public
2020
2121"""Implement key wrapping as described in RFC 3394 and RFC 5649."""
2222
23+ import binascii
24+
2325from Crypto .Cipher import AES
2426from Crypto .Util .number import bytes_to_long , long_to_bytes
2527from Crypto .Util .strxor import strxor
@@ -31,8 +33,8 @@ def _split(value):
3133 return value [:8 ], value [8 :]
3234
3335
34- RFC3394_IV = 'a6a6a6a6a6a6a6a6' . decode ( 'hex ' )
35- RFC5649_IV = 'a65959a6' . decode ( 'hex ' )
36+ RFC3394_IV = binascii . a2b_hex ( 'a6a6a6a6a6a6a6a6 ' )
37+ RFC5649_IV = binascii . a2b_hex ( 'a65959a6 ' )
3638
3739
3840def wrap (plaintext , key , iv = None , pad = None ):
@@ -54,7 +56,7 @@ def wrap(plaintext, key, iv=None, pad=None):
5456 raise EncryptionError ('Plaintext length wrong' )
5557 if mli % 8 != 0 and pad is not False :
5658 r = (mli + 7 ) // 8
57- plaintext += ((r * 8 ) - mli ) * '\0 '
59+ plaintext += ((r * 8 ) - mli ) * b '\0 '
5860
5961 if iv is None :
6062 if len (plaintext ) != mli or pad is True :
@@ -63,7 +65,7 @@ def wrap(plaintext, key, iv=None, pad=None):
6365 iv = RFC3394_IV
6466
6567 encrypt = AES .new (key ).encrypt
66- n = len (plaintext ) / 8
68+ n = len (plaintext ) // 8
6769
6870 if n == 1 :
6971 # RFC 5649 shortcut
@@ -76,7 +78,7 @@ def wrap(plaintext, key, iv=None, pad=None):
7678 for i in range (n ):
7779 A , R [i ] = _split (encrypt (A + R [i ]))
7880 A = strxor (A , long_to_bytes (n * j + i + 1 , 8 ))
79- return A + '' .join (R )
81+ return A + b '' .join (R )
8082
8183
8284def unwrap (ciphertext , key , iv = None , pad = None ):
@@ -95,7 +97,7 @@ def unwrap(ciphertext, key, iv=None, pad=None):
9597 raise DecryptionError ('Ciphertext length wrong' )
9698
9799 decrypt = AES .new (key ).decrypt
98- n = len (ciphertext ) / 8 - 1
100+ n = len (ciphertext ) // 8 - 1
99101
100102 if n == 1 :
101103 A , plaintext = _split (decrypt (ciphertext ))
@@ -107,16 +109,16 @@ def unwrap(ciphertext, key, iv=None, pad=None):
107109 for i in reversed (range (n )):
108110 A = strxor (A , long_to_bytes (n * j + i + 1 , 8 ))
109111 A , R [i ] = _split (decrypt (A + R [i ]))
110- plaintext = '' .join (R )
112+ plaintext = b '' .join (R )
111113
112114 if iv is None :
113115 if A == RFC3394_IV and pad is not True :
114116 return plaintext
115117 elif A [:4 ] == RFC5649_IV and pad is not False :
116118 mli = bytes_to_long (A [4 :])
117- # check padding length is valid and only contains zeros
119+ # check padding length is valid and plaintext only contains zeros
118120 if 8 * (n - 1 ) < mli <= 8 * n and \
119- all ( x == '\0 ' for x in plaintext [ mli :] ):
121+ plaintext . endswith (( len ( plaintext ) - mli ) * b '\0 ' ):
120122 return plaintext [:mli ]
121123 elif A == iv :
122124 return plaintext
0 commit comments