Skip to content

Commit 1c1e507

Browse files
author
ricky
committed
Adding Kerberos Support via TSaslClientTransport
- Added a "sasl" dependency to requirements - Added working TSaslClientTransport - 2 (optional) arguments were added to the Connection class: :use_kerberos | signals to use secure authentication :sasl_service | name of the SASL service (default: hbase)
1 parent 9cbd718 commit 1c1e507

File tree

3 files changed

+198
-2
lines changed

3 files changed

+198
-2
lines changed

happybase/connection.py

+28-2
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,14 @@
99
from thrift.transport.TSocket import TSocket
1010
from thrift.transport.TTransport import TBufferedTransport, TFramedTransport
1111
from thrift.protocol import TBinaryProtocol, TCompactProtocol
12+
import sasl
13+
from os import path
1214

1315
from .hbase import Hbase
1416
from .hbase.ttypes import ColumnDescriptor
1517
from .table import Table
1618
from .util import pep8_to_camel_case
19+
from .thrift_sasl import TSaslClientTransport
1720

1821
logger = logging.getLogger(__name__)
1922

@@ -81,6 +84,11 @@ class Connection(object):
8184
process as well. ``TBinaryAccelerated`` is the default protocol that
8285
happybase uses.
8386
87+
The optional `use_kerberos` argument allows you to establish a
88+
secure connection to HBase. This argument requires a buffered
89+
`transport` protocol. You must first authorize yourself with
90+
your KDC by using kinit (e.g. kinit -kt my.keytab user@REALM)
91+
8492
.. versionadded:: 0.9
8593
`protocol` argument
8694
@@ -101,11 +109,14 @@ class Connection(object):
101109
:param str table_prefix_separator: Separator used for `table_prefix`
102110
:param str compat: Compatibility mode (optional)
103111
:param str transport: Thrift transport mode (optional)
112+
:param bool use_kerberos: Connect to HBase via a secure connection (default: False)
113+
:param str sasl_service: The name of the SASL service (default: hbase)
104114
"""
105115
def __init__(self, host=DEFAULT_HOST, port=DEFAULT_PORT, timeout=None,
106116
autoconnect=True, table_prefix=None,
107117
table_prefix_separator='_', compat=DEFAULT_COMPAT,
108-
transport=DEFAULT_TRANSPORT, protocol=DEFAULT_PROTOCOL):
118+
transport=DEFAULT_TRANSPORT, protocol=DEFAULT_PROTOCOL,
119+
use_kerberos=False, sasl_service="hbase"):
109120

110121
if transport not in THRIFT_TRANSPORTS:
111122
raise ValueError("'transport' must be one of %s"
@@ -135,6 +146,8 @@ def __init__(self, host=DEFAULT_HOST, port=DEFAULT_PORT, timeout=None,
135146
self.table_prefix_separator = table_prefix_separator
136147
self.compat = compat
137148

149+
self._use_kerberos = use_kerberos
150+
self._sasl_service = sasl_service
138151
self._transport_class = THRIFT_TRANSPORTS[transport]
139152
self._protocol_class = THRIFT_PROTOCOLS[protocol]
140153
self._refresh_thrift_client()
@@ -150,7 +163,20 @@ def _refresh_thrift_client(self):
150163
if self.timeout is not None:
151164
socket.setTimeout(self.timeout)
152165

153-
self.transport = self._transport_class(socket)
166+
if not self._use_kerberos:
167+
self.transport = self._transport_class(socket)
168+
else:
169+
# Check for required arguments for kerberos
170+
if self._transport_class is not TBufferedTransport:
171+
raise ValueError("Must use a buffered transport "
172+
" when use_kerberos is enabled")
173+
174+
saslc = sasl.Client()
175+
saslc.setAttr("host", self.host)
176+
saslc.setAttr("service", self._sasl_service)
177+
saslc.init()
178+
self.transport = TSaslClientTransport(saslc, "GSSAPI", socket)
179+
154180
protocol = self._protocol_class(self.transport)
155181
self.client = Hbase.Client(protocol)
156182

happybase/thrift_sasl.py

+169
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
""" SASL transports for Thrift. """
2+
3+
from thrift.transport.TTransport import CReadableTransport, TTransportBase, TTransportException, StringIO
4+
import struct
5+
6+
class TSaslClientTransport(TTransportBase, CReadableTransport):
7+
START = 1
8+
OK = 2
9+
BAD = 3
10+
ERROR = 4
11+
COMPLETE = 5
12+
13+
def __init__(self, sasl_client_factory, mechanism, trans):
14+
"""
15+
@param sasl_client_factory: a callable that returns a new sasl.Client object
16+
@param mechanism: the SASL mechanism (e.g. "GSSAPI", "PLAIN")
17+
@param trans: the underlying transport over which to communicate.
18+
"""
19+
self._trans = trans
20+
self.sasl_client_factory = sasl_client_factory
21+
self.sasl = None
22+
self.mechanism = mechanism
23+
self.__wbuf = StringIO()
24+
self.__rbuf = StringIO()
25+
self.opened = False
26+
self.encode = None
27+
28+
def isOpen(self):
29+
return self._trans.isOpen()
30+
31+
def open(self):
32+
if not self._trans.isOpen():
33+
self._trans.open()
34+
35+
if self.sasl is not None:
36+
raise TTransportException(
37+
type=TTransportException.NOT_OPEN,
38+
message="Already open!")
39+
self.sasl = self.sasl_client_factory
40+
41+
ret, chosen_mech, initial_response = self.sasl.start(self.mechanism)
42+
if not ret:
43+
raise TTransportException(type=TTransportException.NOT_OPEN,
44+
message=("Could not start SASL: %s" % self.sasl.getError()))
45+
46+
# Send initial response
47+
self._send_message(self.START, chosen_mech)
48+
self._send_message(self.OK, initial_response)
49+
50+
# SASL negotiation loop
51+
while True:
52+
status, payload = self._recv_sasl_message()
53+
if status not in (self.OK, self.COMPLETE):
54+
raise TTransportException(type=TTransportException.NOT_OPEN,
55+
message=("Bad status: %d (%s)" % (status, payload)))
56+
if status == self.COMPLETE:
57+
break
58+
ret, response = self.sasl.step(payload)
59+
if not ret:
60+
raise TTransportException(type=TTransportException.NOT_OPEN,
61+
message=("Bad SASL result: %s" % (self.sasl.getError())))
62+
self._send_message(self.OK, response)
63+
64+
def _send_message(self, status, body):
65+
header = struct.pack(">BI", status, len(body))
66+
self._trans.write(header + body)
67+
self._trans.flush()
68+
69+
def _recv_sasl_message(self):
70+
header = self._trans.readAll(5)
71+
status, length = struct.unpack(">BI", header)
72+
if length > 0:
73+
payload = self._trans.readAll(length)
74+
else:
75+
payload = ""
76+
return status, payload
77+
78+
def write(self, data):
79+
self.__wbuf.write(data)
80+
81+
def flush(self):
82+
buffer = self.__wbuf.getvalue()
83+
# The first time we flush data, we send it to sasl.encode()
84+
# If the length doesn't change, then we must be using a QOP
85+
# of auth and we should no longer call sasl.encode(), otherwise
86+
# we encode every time.
87+
if self.encode == None:
88+
success, encoded = self.sasl.encode(buffer)
89+
if not success:
90+
raise TTransportException(type=TTransportException.UNKNOWN,
91+
message=self.sasl.getError())
92+
if (len(encoded)==len(buffer)):
93+
self.encode = False
94+
self._flushPlain(buffer)
95+
else:
96+
self.encode = True
97+
self._trans.write(encoded)
98+
elif self.encode:
99+
self._flushEncoded(buffer)
100+
else:
101+
self._flushPlain(buffer)
102+
103+
self._trans.flush()
104+
self.__wbuf = StringIO()
105+
106+
def _flushEncoded(self, buffer):
107+
# sasl.ecnode() does the encoding and adds the length header, so nothing
108+
# to do but call it and write the result.
109+
success, encoded = self.sasl.encode(buffer)
110+
if not success:
111+
raise TTransportException(type=TTransportException.UNKNOWN,
112+
message=self.sasl.getError())
113+
self._trans.write(encoded)
114+
115+
def _flushPlain(self, buffer):
116+
# When we have QOP of auth, sasl.encode() will pass the input to the output
117+
# but won't put a length header, so we have to do that.
118+
119+
# Note stolen from TFramedTransport:
120+
# N.B.: Doing this string concatenation is WAY cheaper than making
121+
# two separate calls to the underlying socket object. Socket writes in
122+
# Python turn out to be REALLY expensive, but it seems to do a pretty
123+
# good job of managing string buffer operations without excessive copies
124+
self._trans.write(struct.pack(">I", len(buffer)) + buffer)
125+
126+
def read(self, sz):
127+
ret = self.__rbuf.read(sz)
128+
if len(ret) != 0:
129+
return ret
130+
131+
self._read_frame()
132+
return self.__rbuf.read(sz)
133+
134+
def _read_frame(self):
135+
header = self._trans.readAll(4)
136+
(length,) = struct.unpack(">I", header)
137+
if self.encode:
138+
# If the frames are encoded (i.e. you're using a QOP of auth-int or
139+
# auth-conf), then make sure to include the header in the bytes you send to
140+
# sasl.decode()
141+
encoded = header + self._trans.readAll(length)
142+
success, decoded = self.sasl.decode(encoded)
143+
if not success:
144+
raise TTransportException(type=TTransportException.UNKNOWN,
145+
message=self.sasl.getError())
146+
else:
147+
# If the frames are not encoded, just pass it through
148+
decoded = self._trans.readAll(length)
149+
self.__rbuf = StringIO(decoded)
150+
151+
def close(self):
152+
self._trans.close()
153+
self.sasl = None
154+
155+
# Implement the CReadableTransport interface.
156+
# Stolen shamelessly from TFramedTransport
157+
@property
158+
def cstringio_buf(self):
159+
return self.__rbuf
160+
161+
def cstringio_refill(self, prefix, reqlen):
162+
# self.__rbuf will already be empty here because fastbinary doesn't
163+
# ask for a refill until the previous buffer is empty. Therefore,
164+
# we can start reading new frames immediately.
165+
while len(prefix) < reqlen:
166+
self._read_frame()
167+
prefix += self.__rbuf.getvalue()
168+
self.__rbuf = StringIO(prefix)
169+
return self.__rbuf

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
thrift>=0.8.0
2+
sasl

0 commit comments

Comments
 (0)