Skip to content

Expose open/close for connection pooling #13

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
dist/
dist-newstyle/
.cabal-sandbox/
cabal.sandbox.config
node_modules
Expand Down
266 changes: 167 additions & 99 deletions src/Ldap/Client.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,26 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | This module is intended to be imported qualified
--
-- @
-- import qualified Ldap.Client as Ldap
-- @
module Ldap.Client
( with
, with'
, runsIn
, runsInEither
, open
, close
, Host(..)
, defaultTlsSettings
, insecureTlsSettings
, PortNumber
, Ldap
, LdapH
, LdapError(..)
, ResponseError(..)
, Type.ResultCode(..)
Expand Down Expand Up @@ -66,8 +74,9 @@ import qualified Control.Concurrent.Async as Async
import Control.Concurrent.STM (atomically, throwSTM)
import Control.Concurrent.STM.TMVar (putTMVar)
import Control.Concurrent.STM.TQueue (TQueue, newTQueueIO, writeTQueue, readTQueue)
import Control.Exception (Exception, Handler(..), bracket, throwIO, catch, catches)
import Control.Exception (Exception, bracket, throwIO, SomeException, fromException, throw, Handler(..))
import Control.Monad (forever)
import Data.Void (Void)
import qualified Data.ASN1.BinaryEncoding as Asn1
import qualified Data.ASN1.Encoding as Asn1
import qualified Data.ASN1.Error as Asn1
Expand Down Expand Up @@ -114,50 +123,100 @@ import Ldap.Client.Extended (Oid(..), extended, noticeOfDisconnectionO
{-# ANN module ("HLint: ignore Use first" :: String) #-}


newLdap :: IO Ldap
newLdap = Ldap
<$> newTQueueIO

-- | Various failures that can happen when working with LDAP.
data LdapError =
IOError !IOError -- ^ Network failure.
data LdapError
= IOError !IOError -- ^ Network failure.
| ParseError !Asn1.ASN1Error -- ^ Invalid ASN.1 data received from the server.
| ResponseError !ResponseError -- ^ An LDAP operation failed.
| DisconnectError !Disconnect -- ^ Notice of Disconnection has been received.
deriving (Show, Eq)

newtype WrappedIOError = WrappedIOError IOError
deriving (Show, Eq, Typeable)

instance Exception WrappedIOError
instance Exception LdapError

data Disconnect = Disconnect !Type.ResultCode !Dn !Text
deriving (Show, Eq, Typeable)

instance Exception Disconnect

newtype LdapH = LdapH Ldap

-- | Provide a 'LdapH' to a function needing an 'Ldap' handle.
runsIn :: (Ldap -> IO a)
-> LdapH
-> IO a
runsIn act (LdapH ldap) = do
actor <- Async.async (act ldap)
r <- Async.waitEitherCatch (workers ldap) actor
case r of
Left (Right _a) -> error "Unreachable"
Left (Left e) -> throwIO =<< catchesHandler workerErr e
Right (Right r') -> pure r'
Right (Left e) -> throwIO =<< catchesHandler respErr e

-- | Provide a 'LdapH' to a function needing an 'Ldap' handle
runsInEither :: (Ldap -> IO a)
-> LdapH
-> IO (Either LdapError a)
runsInEither act (LdapH ldap) = do
actor <- Async.async (act ldap)
r <- Async.waitEitherCatch (workers ldap) actor
case r of
Left (Right _a) -> error "Unreachable"
Left (Left e) -> do Left <$> catchesHandler workerErr e
Right (Right r') -> pure (Right r')
Right (Left e) -> do Left <$> catchesHandler respErr e


workerErr :: [Handler LdapError]
workerErr = [ Handler (\(ex :: IOError) -> pure (IOError ex))
, Handler (\(ex :: Asn1.ASN1Error) -> pure (ParseError ex))
, Handler (\(ex :: Disconnect) -> pure (DisconnectError ex))
]

respErr :: [Handler LdapError]
respErr = [ Handler (\(ex :: ResponseError) -> pure (ResponseError ex))
]

catchesHandler :: [Handler a] -> SomeException -> IO a
catchesHandler handlers e = foldr tryHandler (throw e) handlers
where tryHandler (Handler handler) res
= case fromException e of
Just e' -> handler e'
Nothing -> res

-- | The entrypoint into LDAP.
with' :: Host -> PortNumber -> (Ldap -> IO a) -> IO a
with' host port act = bracket (open host port) close (runsIn act)

-- | The entrypoint into LDAP.
--
-- It catches all LDAP-related exceptions.
with :: Host -> PortNumber -> (Ldap -> IO a) -> IO (Either LdapError a)
with host port f = do
with host port act = bracket (open host port) close (runsInEither act)

-- | Creates an LDAP handle. This action is useful for creating your own resource
-- management, such as with 'resource-pool'. The handle must be manually closed
-- with 'close'.
open :: Host -> PortNumber -> IO (LdapH)
open host port = do
context <- Conn.initConnectionContext
bracket (Conn.connectTo context params) Conn.connectionClose (\conn ->
bracket newLdap unbindAsync (\l -> do
inq <- newTQueueIO
outq <- newTQueueIO
as <- traverse Async.async
[ input inq conn
, output outq conn
, dispatch l inq outq
, f l
]
fmap (Right . snd) (Async.waitAnyCancel as)))
`catches`
[ Handler (\(WrappedIOError e) -> return (Left (IOError e)))
, Handler (return . Left . ParseError)
, Handler (return . Left . ResponseError)
]
conn <- Conn.connectTo context params
reqQ <- newTQueueIO
inQ <- newTQueueIO
outQ <- newTQueueIO

-- The input worker that reads data off the network.
(inW :: Async.Async Void) <- Async.async (input inQ conn)

-- The output worker that sends data onto the network.
(outW :: Async.Async Void) <- Async.async (output outQ conn)

-- The dispatch worker that sends data between the three queues.
(dispW :: Async.Async Void) <- Async.async (dispatch reqQ inQ outQ)

-- We use this to propagate exceptions between the workers. The `workers` Async is just a tool to
-- exchange exceptions between the entire worker group and another thread.
workers <- Async.async (snd <$> Async.waitAnyCancel [inW, outW, dispW])

pure (LdapH (Ldap reqQ workers conn))
where
params = Conn.ConnectionParams
{ Conn.connectionHostname =
Expand All @@ -172,6 +231,14 @@ with host port f = do
, Conn.connectionUseSocks = Nothing
}

-- | Closes an LDAP connection.
-- This is to be used in together with 'open'.
close :: LdapH -> IO ()
close (LdapH ldap) = do
unbindAsync ldap
Conn.connectionClose (conn ldap)
Async.cancel (workers ldap)

defaultTlsSettings :: Conn.TLSSettings
defaultTlsSettings = Conn.TLSSettingsSimple
{ Conn.settingDisableCertificateValidation = False
Expand All @@ -186,84 +253,85 @@ insecureTlsSettings = Conn.TLSSettingsSimple
, Conn.settingUseServerName = False
}

-- | Reads Asn1 BER encoded chunks off a connection into a TQueue.
input :: FromAsn1 a => TQueue a -> Connection -> IO b
input inq conn = wrap . flip fix [] $ \loop chunks -> do
chunk <- Conn.connectionGet conn 8192
case ByteString.length chunk of
0 -> throwIO (IO.mkIOError IO.eofErrorType "Ldap.Client.input" Nothing Nothing)
_ -> do
let chunks' = chunk : chunks
case Asn1.decodeASN1 Asn1.BER (ByteString.Lazy.fromChunks (reverse chunks')) of
Left Asn1.ParsingPartial
-> loop chunks'
Left e -> throwIO e
Right asn1 -> do
flip fix asn1 $ \loop' asn1' ->
case parseAsn1 asn1' of
Nothing -> return ()
Just (asn1'', a) -> do
atomically (writeTQueue inq a)
loop' asn1''
loop []
input inq conn = loop []
where
loop chunks = do
chunk <- Conn.connectionGet conn 8192
case ByteString.length chunk of
0 -> throwIO (IO.mkIOError IO.eofErrorType "Ldap.Client.input" Nothing Nothing)
_ -> do
let chunks' = chunk : chunks
case Asn1.decodeASN1 Asn1.BER (ByteString.Lazy.fromChunks (reverse chunks')) of
Left Asn1.ParsingPartial
-> loop chunks'
Left e -> throwIO e
Right asn1 -> do
flip fix asn1 $ \loop' asn1' ->
case parseAsn1 asn1' of
Nothing -> return ()
Just (asn1'', a) -> do
atomically (writeTQueue inq a)
loop' asn1''
loop []

-- | Transmits Asn1 DER encoded data from a TQueue into a Connection.
output :: ToAsn1 a => TQueue a -> Connection -> IO b
output out conn = wrap . forever $ do
output out conn = forever $ do
msg <- atomically (readTQueue out)
Conn.connectionPut conn (encode (toAsn1 msg))
where
encode x = Asn1.encodeASN1' Asn1.DER (appEndo x [])

dispatch
:: Ldap
:: TQueue ClientMessage
-> TQueue (Type.LdapMessage Type.ProtocolServerOp)
-> TQueue (Type.LdapMessage Request)
-> IO a
dispatch Ldap { client } inq outq =
flip fix (Map.empty, 1) $ \loop (!req, !counter) ->
loop =<< atomically (asum
[ do New new var <- readTQueue client
writeTQueue outq (Type.LdapMessage (Type.Id counter) new Nothing)
return (Map.insert (Type.Id counter) ([], var) req, counter + 1)
, do Type.LdapMessage mid op _
<- readTQueue inq
res <- case op of
Type.BindResponse {} -> done mid op req
Type.SearchResultEntry {} -> saveUp mid op req
Type.SearchResultReference {} -> return req
Type.SearchResultDone {} -> done mid op req
Type.ModifyResponse {} -> done mid op req
Type.AddResponse {} -> done mid op req
Type.DeleteResponse {} -> done mid op req
Type.ModifyDnResponse {} -> done mid op req
Type.CompareResponse {} -> done mid op req
Type.ExtendedResponse {} -> probablyDisconnect mid op req
Type.IntermediateResponse {} -> saveUp mid op req
return (res, counter)
])
where
saveUp mid op res =
return (Map.adjust (\(stack, var) -> (op : stack, var)) mid res)

done mid op req =
case Map.lookup mid req of
Nothing -> return req
Just (stack, var) -> do
putTMVar var (op :| stack)
return (Map.delete mid req)

probablyDisconnect (Type.Id 0)
(Type.ExtendedResponse
(Type.LdapResult code
(Type.LdapDn (Type.LdapString dn))
(Type.LdapString reason)
_)
moid _)
req =
case moid of
Just (Type.LdapOid oid)
| Oid oid == noticeOfDisconnectionOid -> throwSTM (Disconnect code (Dn dn) reason)
_ -> return req
probablyDisconnect mid op req = done mid op req

wrap :: IO a -> IO a
wrap m = m `catch` (throwIO . WrappedIOError)
dispatch reqq inq outq = loop (Map.empty, 1)
where
saveUp mid op res = return (Map.adjust (\(stack, var) -> (op : stack, var)) mid res)

loop (!req, !counter) =
loop =<< atomically (asum
[ do New new var <- readTQueue reqq
writeTQueue outq (Type.LdapMessage (Type.Id counter) new Nothing)
return (Map.insert (Type.Id counter) ([], var) req, counter + 1)
, do Type.LdapMessage mid op _
<- readTQueue inq
res <- case op of
Type.BindResponse {} -> done mid op req
Type.SearchResultEntry {} -> saveUp mid op req
Type.SearchResultReference {} -> return req
Type.SearchResultDone {} -> done mid op req
Type.ModifyResponse {} -> done mid op req
Type.AddResponse {} -> done mid op req
Type.DeleteResponse {} -> done mid op req
Type.ModifyDnResponse {} -> done mid op req
Type.CompareResponse {} -> done mid op req
Type.ExtendedResponse {} -> probablyDisconnect mid op req
Type.IntermediateResponse {} -> saveUp mid op req
return (res, counter)
])

done mid op req =
case Map.lookup mid req of
Nothing -> return req
Just (stack, var) -> do
putTMVar var (op :| stack)
return (Map.delete mid req)

probablyDisconnect (Type.Id 0)
(Type.ExtendedResponse
(Type.LdapResult code
(Type.LdapDn (Type.LdapString dn))
(Type.LdapString reason)
_)
moid _)
req =
case moid of
Just (Type.LdapOid oid)
| Oid oid == noticeOfDisconnectionOid -> throwSTM (Disconnect code (Dn dn) reason)
_ -> return req
probablyDisconnect mid op req = done mid op req
2 changes: 1 addition & 1 deletion src/Ldap/Client/Add.hs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import Ldap.Client.Internal
-- | Perform the Add operation synchronously. Raises 'ResponseError' on failures.
add :: Ldap -> Dn -> AttrList NonEmpty -> IO ()
add l dn as =
raise =<< addEither l dn as
eitherToIO =<< addEither l dn as

-- | Perform the Add operation synchronously. Returns @Left e@ where
-- @e@ is a 'ResponseError' on failures.
Expand Down
4 changes: 2 additions & 2 deletions src/Ldap/Client/Bind.hs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ newtype Password = Password ByteString
-- | Perform the Bind operation synchronously. Raises 'ResponseError' on failures.
bind :: Ldap -> Dn -> Password -> IO ()
bind l username password =
raise =<< bindEither l username password
eitherToIO =<< bindEither l username password

-- | Perform the Bind operation synchronously. Returns @Left e@ where
-- @e@ is a 'ResponseError' on failures.
Expand Down Expand Up @@ -82,7 +82,7 @@ bindResult req res = Left (ResponseInvalid req res)
-- | Perform a SASL EXTERNAL Bind operation synchronously. Raises 'ResponseError' on failures.
externalBind :: Ldap -> Dn -> Maybe Text -> IO ()
externalBind l username mCredentials =
raise =<< externalBindEither l username mCredentials
eitherToIO =<< externalBindEither l username mCredentials

-- | Perform a SASL EXTERNAL Bind operation synchronously. Returns @Left e@ where
-- @e@ is a 'ResponseError' on failures.
Expand Down
2 changes: 1 addition & 1 deletion src/Ldap/Client/Compare.hs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import qualified Ldap.Asn1.Type as Type
-- | Perform the Compare operation synchronously. Raises 'ResponseError' on failures.
compare :: Ldap -> Dn -> Attr -> AttrValue -> IO Bool
compare l dn k v =
raise =<< compareEither l dn k v
eitherToIO =<< compareEither l dn k v

-- | Perform the Compare operation synchronously. Returns @Left e@ where
-- @e@ is a 'ResponseError' on failures.
Expand Down
2 changes: 1 addition & 1 deletion src/Ldap/Client/Delete.hs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import Ldap.Client.Internal
-- | Perform the Delete operation synchronously. Raises 'ResponseError' on failures.
delete :: Ldap -> Dn -> IO ()
delete l dn =
raise =<< deleteEither l dn
eitherToIO =<< deleteEither l dn

-- | Perform the Delete operation synchronously. Returns @Left e@ where
-- @e@ is a 'ResponseError' on failures.
Expand Down
Loading