1
+ """ SASL transports for Thrift. """
2
+
3
+ from thrift .transport .TTransport import CReadableTransport , TTransportBase , TTransportException , StringIO
4
+ import struct
5
+
6
+ class TSaslClientTransport (TTransportBase , CReadableTransport ):
7
+ START = 1
8
+ OK = 2
9
+ BAD = 3
10
+ ERROR = 4
11
+ COMPLETE = 5
12
+
13
+ def __init__ (self , sasl_client_factory , mechanism , trans ):
14
+ """
15
+ @param sasl_client_factory: a callable that returns a new sasl.Client object
16
+ @param mechanism: the SASL mechanism (e.g. "GSSAPI", "PLAIN")
17
+ @param trans: the underlying transport over which to communicate.
18
+ """
19
+ self ._trans = trans
20
+ self .sasl_client_factory = sasl_client_factory
21
+ self .sasl = None
22
+ self .mechanism = mechanism
23
+ self .__wbuf = StringIO ()
24
+ self .__rbuf = StringIO ()
25
+ self .opened = False
26
+ self .encode = None
27
+
28
+ def isOpen (self ):
29
+ return self ._trans .isOpen ()
30
+
31
+ def open (self ):
32
+ if not self ._trans .isOpen ():
33
+ self ._trans .open ()
34
+
35
+ if self .sasl is not None :
36
+ raise TTransportException (
37
+ type = TTransportException .NOT_OPEN ,
38
+ message = "Already open!" )
39
+ self .sasl = self .sasl_client_factory
40
+
41
+ ret , chosen_mech , initial_response = self .sasl .start (self .mechanism )
42
+ if not ret :
43
+ raise TTransportException (type = TTransportException .NOT_OPEN ,
44
+ message = ("Could not start SASL: %s" % self .sasl .getError ()))
45
+
46
+ # Send initial response
47
+ self ._send_message (self .START , chosen_mech )
48
+ self ._send_message (self .OK , initial_response )
49
+
50
+ # SASL negotiation loop
51
+ while True :
52
+ status , payload = self ._recv_sasl_message ()
53
+ if status not in (self .OK , self .COMPLETE ):
54
+ raise TTransportException (type = TTransportException .NOT_OPEN ,
55
+ message = ("Bad status: %d (%s)" % (status , payload )))
56
+ if status == self .COMPLETE :
57
+ break
58
+ ret , response = self .sasl .step (payload )
59
+ if not ret :
60
+ raise TTransportException (type = TTransportException .NOT_OPEN ,
61
+ message = ("Bad SASL result: %s" % (self .sasl .getError ())))
62
+ self ._send_message (self .OK , response )
63
+
64
+ def _send_message (self , status , body ):
65
+ header = struct .pack (">BI" , status , len (body ))
66
+ self ._trans .write (header + body )
67
+ self ._trans .flush ()
68
+
69
+ def _recv_sasl_message (self ):
70
+ header = self ._trans .readAll (5 )
71
+ status , length = struct .unpack (">BI" , header )
72
+ if length > 0 :
73
+ payload = self ._trans .readAll (length )
74
+ else :
75
+ payload = ""
76
+ return status , payload
77
+
78
+ def write (self , data ):
79
+ self .__wbuf .write (data )
80
+
81
+ def flush (self ):
82
+ buffer = self .__wbuf .getvalue ()
83
+ # The first time we flush data, we send it to sasl.encode()
84
+ # If the length doesn't change, then we must be using a QOP
85
+ # of auth and we should no longer call sasl.encode(), otherwise
86
+ # we encode every time.
87
+ if self .encode == None :
88
+ success , encoded = self .sasl .encode (buffer )
89
+ if not success :
90
+ raise TTransportException (type = TTransportException .UNKNOWN ,
91
+ message = self .sasl .getError ())
92
+ if (len (encoded )== len (buffer )):
93
+ self .encode = False
94
+ self ._flushPlain (buffer )
95
+ else :
96
+ self .encode = True
97
+ self ._trans .write (encoded )
98
+ elif self .encode :
99
+ self ._flushEncoded (buffer )
100
+ else :
101
+ self ._flushPlain (buffer )
102
+
103
+ self ._trans .flush ()
104
+ self .__wbuf = StringIO ()
105
+
106
+ def _flushEncoded (self , buffer ):
107
+ # sasl.ecnode() does the encoding and adds the length header, so nothing
108
+ # to do but call it and write the result.
109
+ success , encoded = self .sasl .encode (buffer )
110
+ if not success :
111
+ raise TTransportException (type = TTransportException .UNKNOWN ,
112
+ message = self .sasl .getError ())
113
+ self ._trans .write (encoded )
114
+
115
+ def _flushPlain (self , buffer ):
116
+ # When we have QOP of auth, sasl.encode() will pass the input to the output
117
+ # but won't put a length header, so we have to do that.
118
+
119
+ # Note stolen from TFramedTransport:
120
+ # N.B.: Doing this string concatenation is WAY cheaper than making
121
+ # two separate calls to the underlying socket object. Socket writes in
122
+ # Python turn out to be REALLY expensive, but it seems to do a pretty
123
+ # good job of managing string buffer operations without excessive copies
124
+ self ._trans .write (struct .pack (">I" , len (buffer )) + buffer )
125
+
126
+ def read (self , sz ):
127
+ ret = self .__rbuf .read (sz )
128
+ if len (ret ) != 0 :
129
+ return ret
130
+
131
+ self ._read_frame ()
132
+ return self .__rbuf .read (sz )
133
+
134
+ def _read_frame (self ):
135
+ header = self ._trans .readAll (4 )
136
+ (length ,) = struct .unpack (">I" , header )
137
+ if self .encode :
138
+ # If the frames are encoded (i.e. you're using a QOP of auth-int or
139
+ # auth-conf), then make sure to include the header in the bytes you send to
140
+ # sasl.decode()
141
+ encoded = header + self ._trans .readAll (length )
142
+ success , decoded = self .sasl .decode (encoded )
143
+ if not success :
144
+ raise TTransportException (type = TTransportException .UNKNOWN ,
145
+ message = self .sasl .getError ())
146
+ else :
147
+ # If the frames are not encoded, just pass it through
148
+ decoded = self ._trans .readAll (length )
149
+ self .__rbuf = StringIO (decoded )
150
+
151
+ def close (self ):
152
+ self ._trans .close ()
153
+ self .sasl = None
154
+
155
+ # Implement the CReadableTransport interface.
156
+ # Stolen shamelessly from TFramedTransport
157
+ @property
158
+ def cstringio_buf (self ):
159
+ return self .__rbuf
160
+
161
+ def cstringio_refill (self , prefix , reqlen ):
162
+ # self.__rbuf will already be empty here because fastbinary doesn't
163
+ # ask for a refill until the previous buffer is empty. Therefore,
164
+ # we can start reading new frames immediately.
165
+ while len (prefix ) < reqlen :
166
+ self ._read_frame ()
167
+ prefix += self .__rbuf .getvalue ()
168
+ self .__rbuf = StringIO (prefix )
169
+ return self .__rbuf
0 commit comments