1181 lines
22 KiB
C++
Raw Normal View History

2001-01-01 00:00:00 +01:00
/*++
Copyright (c) 1994 Microsoft Corporation
Module Name:
iwinsock.cxx
Abstract:
Contains functions to load sockets DLL and entry points. Functions and data
in this module take care of indirecting sockets calls, hence _I_ in front
of the function name
Contents:
IwinsockInitialize
IwinsockTerminate
LoadWinsock
UnloadWinsock
SafeCloseSocket
Author:
Richard L Firth (rfirth) 12-Apr-1995
Environment:
Win32(s) user-mode DLL
Revision History:
12-Apr-1995 rfirth
Created
08-May-1996 arthurbi
Added support for Socks Firewalls.
05-Mar-1998 rfirth
Moved SOCKS support into ICSocket class. Removed SOCKS library
loading/unloading from this module (revert to pre-SOCKS)
--*/
#include <wininetp.h>
#if defined(__cplusplus)
extern "C" {
#endif
//
// RLF_DEBUG is a developer-only switch: when it is defined in a debug build,
// DPRINTF is routed to dprintf. In every other configuration DPRINTF(...)
// expands to a (void) cast of a comma expression, so its arguments are
// evaluated and the result discarded
//
//#define RLF_DEBUG 1
#if INET_DEBUG
#ifdef RLF_DEBUG
#define DPRINTF dprintf
#else
#define DPRINTF (void)
#endif
//
// socket-tracking support (debug builds only); defined later in this file
//
BOOL
InitDebugSock(
VOID
);
VOID
TerminateDebugSock(
VOID
);
#else
#define DPRINTF (void)
#endif
//
// private types
//
typedef struct {
LPSTR FunctionOrdinal;
FARPROC * FunctionAddress;
} SOCKETS_FUNCTION;
//
// global data
//
//
// Indirected Winsock entry points. Each pointer starts out NULL, is filled in
// from SocketsFunctions[] by LoadWinsock() and is reset to NULL by
// UnloadWinsock(); callers go through these instead of linking to the
// sockets DLL directly
//
GLOBAL SOCKET    (PASCAL FAR * _I_accept)(SOCKET s, struct sockaddr FAR *addr, int FAR *addrlen) = NULL;
GLOBAL int       (PASCAL FAR * _I_bind)(SOCKET s, const struct sockaddr FAR *addr, int namelen) = NULL;
GLOBAL int       (PASCAL FAR * _I_closesocket)(SOCKET s) = NULL;
GLOBAL int       (PASCAL FAR * _I_connect)(SOCKET s, const struct sockaddr FAR *name, int namelen) = NULL;
GLOBAL int       (PASCAL FAR * _I_gethostname)(char FAR * name, int namelen) = NULL;
GLOBAL LPHOSTENT (PASCAL FAR * _I_gethostbyname)(LPSTR lpHostName) = NULL;
GLOBAL int       (PASCAL FAR * _I_getsockname)(SOCKET s, struct sockaddr FAR *name, int FAR * namelen) = NULL;
//
// _I_getsockopt - indirected getsockopt(). Explicitly initialized to NULL
// like every other entry-point pointer in this module, so its "not loaded"
// state is stated rather than left to implicit static zero-initialization
//
GLOBAL
int
(PASCAL FAR * _I_getsockopt)(
    SOCKET s,
    int level,
    int optname,
    char FAR * optval,
    int FAR *optlen
    ) = NULL;
//
// More indirected Winsock entry points; same lifecycle as the ones above:
// NULL until LoadWinsock() resolves them, NULL again after UnloadWinsock()
//
GLOBAL u_long        (PASCAL FAR * _I_htonl)(u_long hostlong) = NULL;
GLOBAL u_short       (PASCAL FAR * _I_htons)(u_short hostshort) = NULL;
GLOBAL unsigned long (PASCAL FAR * _I_inet_addr)(const char FAR * cp) = NULL;
GLOBAL char FAR *    (PASCAL FAR * _I_inet_ntoa)(struct in_addr in) = NULL;
GLOBAL int           (PASCAL FAR * _I_ioctlsocket)(SOCKET s, long cmd, u_long FAR *argp) = NULL;
GLOBAL int           (PASCAL FAR * _I_listen)(SOCKET s, int backlog) = NULL;
GLOBAL u_short       (PASCAL FAR * _I_ntohs)(u_short netshort) = NULL;
GLOBAL int           (PASCAL FAR * _I_recv)(SOCKET s, char FAR * buf, int len, int flags) = NULL;

GLOBAL int (PASCAL FAR * _I_WSARecv)(SOCKET s,
                                      LPWSABUF lpBuffers,
                                      DWORD dwBufferCount,
                                      LPDWORD lpNumberOfBytesRecvd,
                                      LPDWORD lpFlags,
                                      LPWSAOVERLAPPED lpOverlapped,
                                      LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine) = NULL;

GLOBAL int (PASCAL FAR * _I_recvfrom)(SOCKET s,
                                       char FAR * buf,
                                       int len,
                                       int flags,
                                       struct sockaddr FAR *from,
                                       int FAR * fromlen) = NULL;

GLOBAL int (PASCAL FAR * _I_select)(int nfds,
                                     fd_set FAR *readfds,
                                     fd_set FAR *writefds,
                                     fd_set FAR *exceptfds,
                                     const struct timeval FAR *timeout) = NULL;

GLOBAL int (PASCAL FAR * _I_send)(SOCKET s, const char FAR * buf, int len, int flags) = NULL;

GLOBAL int (PASCAL FAR * _I_WSASend)(SOCKET s,
                                      LPWSABUF lpBuffers,
                                      DWORD dwBufferCount,
                                      LPDWORD lpNumberOfBytesSent,
                                      DWORD dwFlags,
                                      LPWSAOVERLAPPED lpOverlapped,
                                      LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine) = NULL;

GLOBAL int (PASCAL FAR * _I_sendto)(SOCKET s,
                                     const char FAR * buf,
                                     int len,
                                     int flags,
                                     const struct sockaddr FAR *to,
                                     int tolen) = NULL;

GLOBAL int (PASCAL FAR * _I_setsockopt)(SOCKET s,
                                         int level,
                                         int optname,
                                         const char FAR * optval,
                                         int optlen) = NULL;

GLOBAL int    (PASCAL FAR * _I_shutdown)(SOCKET s, int how) = NULL;
GLOBAL SOCKET (PASCAL FAR * _I_socket)(int af, int type, int protocol) = NULL;
GLOBAL int    (PASCAL FAR * _I_WSAStartup)(WORD wVersionRequired, LPWSADATA lpWSAData) = NULL;
GLOBAL int    (PASCAL FAR * _I_WSACleanup)(void) = NULL;
//
// __I_WSAGetLastError - the real WSAGetLastError() export, filled in from
// SocketsFunctions[] by LoadWinsock() and reset to NULL by UnloadWinsock().
// Code calls _I_WSAGetLastError instead, which points at the wrapper below.
//
// VENKATKBUG: remove the wrapper later; for now it is a single place to trap
// Winsock error codes while debugging
//
GLOBAL
int
(PASCAL FAR * __I_WSAGetLastError)(
    void
    ) = NULL;

//
// ___I_WSAGetLastError - forwards to the loaded WSAGetLastError().
//
// Return Value:
//     the current Winsock error code, or WSANOTINITIALISED if the sockets DLL
//     is not loaded. In that case the entry point is NULL and calling it
//     would fault. _I_WSAGetLastError always points here, so callers'
//     "_I_WSAGetLastError != NULL" checks cannot catch that case themselves.
//
int
___I_WSAGetLastError(
    VOID
    )
{
    if (__I_WSAGetLastError == NULL) {
        return WSANOTINITIALISED;
    }

    //
    // WSAENOTSOCK is deliberately not asserted against: it is an expected
    // result, e.g. in timeout situations
    //

    return __I_WSAGetLastError();
}

GLOBAL
int
(PASCAL FAR * _I_WSAGetLastError)(
    void
    ) = ___I_WSAGetLastError;
//
// remaining indirected entry points: WSASetLastError() and the
// __WSAFDIsSet() helper export; both NULL until LoadWinsock() resolves them
//
GLOBAL void (PASCAL FAR * _I_WSASetLastError)(int iError) = NULL;
GLOBAL int  (PASCAL FAR * _I___WSAFDIsSet)(SOCKET, fd_set FAR *) = NULL;
//
// SetupSocketsTracing() (debug builds only) is prototyped once, in the
// "private prototypes" section below; a second identical prototype here was
// redundant and has been removed
//
//
// private data
//
//
// InitializationLock - serializes LoadWinsock() and UnloadWinsock(), so only
// one thread at a time changes whether the sockets DLL is loaded and its
// entry points resolved
//
PRIVATE CCritSec InitializationLock;
//
// hWinsock - module handle of the loaded sockets DLL: ws2_32.dll, or
// wsock32.dll if ws2_32.dll could not be loaded. NULL when no sockets DLL is
// loaded
//
PRIVATE HINSTANCE hWinsock = NULL;
//
// WinsockLoadCount - number of additional ("virtual") loads outstanding.
// LoadWinsock() increments it only when the DLL is already loaded, and
// UnloadWinsock() decrements it while it is non-zero. The first, real load
// leaves it at 0, so an UnloadWinsock() that finds it at 0 unloads the DLL
// for real
//
PRIVATE DWORD WinsockLoadCount = 0;
//
// SocketsFunctions - every entry point WININET.DLL resolves from the sockets
// DLL (ws2_32.dll, or wsock32.dll as a fallback), paired with the global
// pointer that receives it. LoadWinsock() walks this table in order and gives
// up if any name cannot be found.
//
PRIVATE
SOCKETS_FUNCTION
SocketsFunctions[] = {
    { "accept",          (FARPROC*)&_I_accept           },
    { "bind",            (FARPROC*)&_I_bind             },
    { "closesocket",     (FARPROC*)&_I_closesocket      },
    { "connect",         (FARPROC*)&_I_connect          },
    { "getsockname",     (FARPROC*)&_I_getsockname      },
    { "getsockopt",      (FARPROC*)&_I_getsockopt       },
    { "htonl",           (FARPROC*)&_I_htonl            },
    { "htons",           (FARPROC*)&_I_htons            },
    { "inet_addr",       (FARPROC*)&_I_inet_addr        },
    { "inet_ntoa",       (FARPROC*)&_I_inet_ntoa        },
    { "ioctlsocket",     (FARPROC*)&_I_ioctlsocket      },
    { "listen",          (FARPROC*)&_I_listen           },
    { "ntohs",           (FARPROC*)&_I_ntohs            },
    { "recv",            (FARPROC*)&_I_recv             },
    { "recvfrom",        (FARPROC*)&_I_recvfrom         },
    { "select",          (FARPROC*)&_I_select           },
    { "send",            (FARPROC*)&_I_send             },
    { "sendto",          (FARPROC*)&_I_sendto           },
    { "setsockopt",      (FARPROC*)&_I_setsockopt       },
    { "shutdown",        (FARPROC*)&_I_shutdown         },
    { "socket",          (FARPROC*)&_I_socket           },
    { "gethostbyname",   (FARPROC*)&_I_gethostbyname    },
    { "gethostname",     (FARPROC*)&_I_gethostname      },
    { "WSAGetLastError", (FARPROC*)&__I_WSAGetLastError },
    { "WSASetLastError", (FARPROC*)&_I_WSASetLastError  },
    { "WSAStartup",      (FARPROC*)&_I_WSAStartup       },
    { "WSACleanup",      (FARPROC*)&_I_WSACleanup       },
    { "__WSAFDIsSet",    (FARPROC*)&_I___WSAFDIsSet     },
    { "WSARecv",         (FARPROC*)&_I_WSARecv          },
    { "WSASend",         (FARPROC*)&_I_WSASend          }
};
//
// private prototypes
//
#if INET_DEBUG
// installs the debug socket-tracking hooks; called from LoadWinsock()
void SetupSocketsTracing(void);
#endif
//
// functions
//
BOOL
IwinsockInitialize(
    VOID
    )

/*++

Routine Description:

    Performs initialization/resource allocation for this module. It sets up
    the critical section that serializes LoadWinsock()/UnloadWinsock(). In
    debug builds it also sets up the socket-tracking state.

Arguments:

    None.

Return Value:

    BOOL
        TRUE  - everything was initialized

        FALSE - a critical section could not be initialized

--*/

{
    BOOL lockReady = InitializationLock.Init();

    if (!lockReady) {
        return lockReady;
    }

#if INET_DEBUG
    return InitDebugSock();
#else
    return lockReady;
#endif
}
VOID
IwinsockTerminate(
    VOID
    )

/*++

Routine Description:

    Releases what IwinsockInitialize() set up, in reverse order: first the
    debug-only socket-tracking lock, then the critical section guarding
    Winsock load/unload

Arguments:

    None.

Return Value:

    None.

--*/

{
#if INET_DEBUG
    TerminateDebugSock();
#endif

    InitializationLock.FreeLock();
}
DWORD
LoadWinsock(
    VOID
    )

/*++

Routine Description:

    Dynamically loads the Windows sockets library: ws2_32.dll, falling back to
    wsock32.dll. It resolves every entry point in SocketsFunctions[] and
    calls WSAStartup() for Winsock 1.1.

    If the library is already loaded, the call only records another virtual
    load (WinsockLoadCount). Every successful call must be balanced by a call
    to UnloadWinsock().

    If loading fails part-way, any entry points already resolved are reset to
    NULL and the library is released. WSACleanup() is not called and async
    support is not touched, because WSAStartup() never succeeded on a failure
    path.

Arguments:

    None.

Return Value:

    DWORD
        Success - ERROR_SUCCESS

        Failure - ERROR_NOT_SUPPORTED
                    the library could not be loaded, an entry point is
                    missing or WSAStartup() failed (TCP/IP not available)

                  ERROR_NOT_ENOUGH_MEMORY
                    the initialization lock could not be acquired

--*/

{
    DEBUG_ENTER((DBG_SOCKETS,
                 Dword,
                 "LoadWinsock",
                 NULL
                 ));

    DWORD error = ERROR_SUCCESS;

    //
    // ensure no 2 threads are trying to modify the loaded state of winsock at
    // the same time
    //

    if (!InitializationLock.Lock()) {
        error = ERROR_NOT_ENOUGH_MEMORY;
        goto quit;
    }

    if (hWinsock == NULL) {

        BOOL failed = FALSE;

        //
        // BUGBUG - read this value from registry
        //

        hWinsock = LoadLibrary("ws2_32");
        if (hWinsock == NULL) {

            DEBUG_PRINT(SOCKETS,
                        INFO,
                        ("failed to load ws2_32.dll"));

            hWinsock = LoadLibrary("wsock32");
        }

        if (hWinsock != NULL) {

            //
            // load the entry points
            //

            for (int i = 0; i < ARRAY_ELEMENTS(SocketsFunctions); ++i) {

                FARPROC farProc;

                farProc = GetProcAddress(
                            hWinsock,
                            (LPCSTR)SocketsFunctions[i].FunctionOrdinal
                            );
                if (farProc == NULL) {
                    failed = TRUE;
                    break;
                }
                *SocketsFunctions[i].FunctionAddress = farProc;
            }

            if (!failed) {

                //
                // although we need a WSADATA for WSAStartup(), it is an
                // expendable structure (not required for any other sockets
                // calls). 0x0101 requests Winsock version 1.1
                //

                WSADATA wsaData;

                error = _I_WSAStartup(0x0101, &wsaData);
                if (error == ERROR_SUCCESS) {

                    DEBUG_PRINT(SOCKETS,
                                INFO,
                                ("winsock description: %q\n",
                                wsaData.szDescription
                                ));

                    int stringLen;

                    stringLen = lstrlen(wsaData.szDescription);
                    if (strnistr(wsaData.szDescription, "novell", stringLen)
                    && strnistr(wsaData.szDescription, "wsock32", stringLen)) {

                        DEBUG_PRINT(SOCKETS,
                                    INFO,
                                    ("running on Novell Client32 stack\n"
                                    ));

                        GlobalRunningNovellClient32 = TRUE;
                    }

#if INET_DEBUG
                    SetupSocketsTracing();
#endif

                } else {
                    failed = TRUE;
                }
            }
        } else {
            failed = TRUE;
        }

        //
        // if we failed to find an entry point or WSAStartup() returned an error
        // then unload the library
        //

        if (failed) {

            //
            // important: there should be no API calls between determining the
            // failure and coming here to get the error code
            //
            // if error == ERROR_SUCCESS then we have to get the last error, else
            // it is the error returned by WSAStartup()
            //

            if (error == ERROR_SUCCESS) {
                error = GetLastError();

                INET_ASSERT(error != ERROR_SUCCESS);

            }

            //
            // WSAStartup() has not succeeded on any path that reaches here.
            // Going through UnloadWinsock() would therefore call
            // TerminateAsyncSupport() and WSACleanup() without a matching
            // successful WSAStartup(). Instead, forget the entry points
            // resolved so far and release the library directly.
            //

            for (int i = 0; i < ARRAY_ELEMENTS(SocketsFunctions); ++i) {
                *SocketsFunctions[i].FunctionAddress = (FARPROC)NULL;
            }
            if (hWinsock != NULL) {
                FreeLibrary(hWinsock);
                hWinsock = NULL;
            }
        }
    } else {

        //
        // just increment the number of times we have called LoadWinsock()
        // without a corresponding call to UnloadWinsock();
        //

        ++WinsockLoadCount;
    }

    InitializationLock.Unlock();

    //
    // if we failed for any reason, need to report that TCP/IP not available
    //

    if (error != ERROR_SUCCESS) {
        error = ERROR_NOT_SUPPORTED;
    }

quit:

    DEBUG_LEAVE(error);

    return error;
}
VOID
UnloadWinsock(
    VOID
    )

/*++

Routine Description:

    Drops one load reference on the Windows sockets library.

    While extra virtual loads remain (WinsockLoadCount != 0), only the count
    is decremented. Otherwise the library is released for real:
      - async support is terminated (it relies on Winsock);
      - WSACleanup() is called;
      - every SocketsFunctions[] entry point is reset to NULL;
      - the DLL is freed, leaving hWinsock NULL so it can be reloaded.

    Nothing happens if no sockets DLL is loaded, or if the initialization
    lock cannot be acquired.

Arguments:

    None.

Return Value:

    None.

--*/

{
    DEBUG_ENTER((DBG_SOCKETS,
                 None,
                 "UnloadWinsock",
                 NULL
                 ));

    if (InitializationLock.Lock()) {

        if ((hWinsock != NULL) && (WinsockLoadCount != 0)) {

            //
            // other virtual loads are still outstanding
            //

            --WinsockLoadCount;

        } else if (hWinsock != NULL) {

            INET_ASSERT(_I_WSACleanup != NULL);

            if (_I_WSACleanup != NULL) {

                //
                // NOTE(review): the original author states this path is not
                // reached during dynamic unload, which is why tearing down
                // async support here is considered safe - confirm with callers
                //

                TerminateAsyncSupport(TRUE);

                int cleanupResult = _I_WSACleanup();

                if (cleanupResult != 0) {
                    DEBUG_PRINT(SOCKETS,
                                ERROR,
                                ("WSACleanup() returns %d; WSA error = %d\n",
                                cleanupResult,
                                (_I_WSAGetLastError != NULL)
                                    ? _I_WSAGetLastError()
                                    : -1
                                ));
                }
            }

            for (int slot = 0; slot < ARRAY_ELEMENTS(SocketsFunctions); ++slot) {
                *SocketsFunctions[slot].FunctionAddress = (FARPROC)NULL;
            }

            FreeLibrary(hWinsock);
            hWinsock = NULL;
        }

        InitializationLock.Unlock();
    }

    DEBUG_LEAVE(0);
}
DWORD
SafeCloseSocket(
    IN SOCKET Socket
    )

/*++

Routine Description:

    Calls closesocket() inside a structured exception handler. Any exception
    raised by the call is swallowed and treated as a successful close. The
    motivating case is a winsock DLL unloaded by the system before the
    Wininet DLL.

Arguments:

    Socket  - socket handle to close

Return Value:

    DWORD
        Success - ERROR_SUCCESS (also returned when the call faulted)

        Failure - the current Winsock error mapped through MapInternetError()

--*/

{
    int closeResult;

    __try {
        closeResult = _I_closesocket(Socket);
    } __except(EXCEPTION_EXECUTE_HANDLER) {

        //
        // the call faulted: report it as closed
        //

        closeResult = 0;
    }
    ENDEXCEPT

    if (closeResult != SOCKET_ERROR) {
        return ERROR_SUCCESS;
    }
    return MapInternetError(_I_WSAGetLastError());
}
//
// GetWrapOverlappedObject - given the address of the m_Overlapped member of a
// CWrapOverlapped, returns the address of the enclosing CWrapOverlapped.
// lpAddress is not validated; it must really point at m_Overlapped.
//
CWrapOverlapped* GetWrapOverlappedObject(LPVOID lpAddress)
{
    CWrapOverlapped* pOwner = CONTAINING_RECORD(lpAddress, CWrapOverlapped, m_Overlapped);

    return pOwner;
}
#if INET_DEBUG
//
// debug socket-tracing hooks. SetupSocketsTracing() saves the real accept,
// closesocket and socket entry points in the _P_* pointers, then points the
// corresponding _I_* pointers at the _II_* wrappers declared here
//
SOCKET PASCAL FAR _II_socket(int af, int type, int protocol);
int    PASCAL FAR _II_closesocket(SOCKET s);
SOCKET PASCAL FAR _II_accept(SOCKET s, struct sockaddr FAR *addr, int FAR *addrlen);

GLOBAL SOCKET (PASCAL FAR * _P_accept)(SOCKET s, struct sockaddr FAR *addr, int FAR *addrlen) = NULL;
GLOBAL int    (PASCAL FAR * _P_closesocket)(SOCKET s) = NULL;
GLOBAL SOCKET (PASCAL FAR * _P_socket)(int af, int type, int protocol) = NULL;
#define MAX_STACK_TRACE 5
#define MAX_SOCK_ENTRIES 1000
typedef struct _DEBUG_SOCK_ENTRY {
SOCKET Socket;
DWORD StackTraceLength;
PVOID StackTrace[ MAX_STACK_TRACE ];
} DEBUG_SOCK_ENTRY, *LPDEBUG_SOCK_ENTRY;
CCritSec DebugSockLock;
DEBUG_SOCK_ENTRY GlobalSockEntry[MAX_SOCK_ENTRIES];
DWORD GlobalSocketsCount = 0;
#define LOCK_DEBUG_SOCK() (DebugSockLock.Lock())
#define UNLOCK_DEBUG_SOCK() (DebugSockLock.Unlock())
HINSTANCE NtDllHandle;
typedef USHORT (*RTL_CAPTURE_STACK_BACK_TRACE)(
IN ULONG FramesToSkip,
IN ULONG FramesToCapture,
OUT PVOID *BackTrace,
OUT PULONG BackTraceHash
);
RTL_CAPTURE_STACK_BACK_TRACE pRtlCaptureStackBackTrace;
//
// InitDebugSock - empties the socket-tracking table and initializes the lock
// that protects it. Returns TRUE on success; asserts and returns FALSE if the
// lock cannot be initialized.
//
BOOL
InitDebugSock(
    VOID
    )
{
    memset(GlobalSockEntry, 0x0, sizeof(GlobalSockEntry));
    GlobalSocketsCount = 0;

    if (DebugSockLock.Init()) {
        return TRUE;
    }

    INET_ASSERT(FALSE);
    return FALSE;
}
//
// TerminateDebugSock - releases the lock created by InitDebugSock(); the
// tracking table itself is static and needs no cleanup
//
VOID
TerminateDebugSock(
    VOID
    )
{
    DebugSockLock.FreeLock();
}
VOID
SetupSocketsTracing(
    VOID
    )

/*++

Routine Description:

    Debug builds only. Hooks socket creation and closing for tracking, but
    only when DBG_TRACE_SOCKETS is enabled and the platform is Windows NT.

    The real accept/closesocket/socket entry points are saved in _P_* and
    replaced in _I_* by the _II_* wrappers. Every socket created or closed is
    then recorded in GlobalSockEntry[].

    RtlCaptureStackBackTrace() from ntdll.dll is required. If it cannot be
    found, nothing is hooked.

    LoadWinsock() calls this after every successful real load of the sockets
    DLL, so it can run again after an unload/reload cycle. ntdll.dll is
    therefore loaded only once and the handle reused.

Arguments:

    None.

Return Value:

    None.

--*/

{
    if (!(InternetDebugCategoryFlags & DBG_TRACE_SOCKETS)) {
        return;
    }
    if (!IsPlatformWinNT()) {
        return;
    }

    //
    // reuse the reference taken on a previous call instead of leaking a new
    // LoadLibrary() reference on every reload of the sockets DLL
    //

    if (NtDllHandle == NULL) {
        NtDllHandle = LoadLibrary("ntdll.dll");
        if (NtDllHandle == NULL) {
            return;
        }
    }

    pRtlCaptureStackBackTrace = (RTL_CAPTURE_STACK_BACK_TRACE)
        GetProcAddress(NtDllHandle, "RtlCaptureStackBackTrace");
    if (pRtlCaptureStackBackTrace == NULL) {
        FreeLibrary(NtDllHandle);

        //
        // don't leave a dangling handle for the next call to reuse
        //

        NtDllHandle = NULL;
        return;
    }

    _P_accept = _I_accept;
    _I_accept = _II_accept;

    _P_closesocket = _I_closesocket;
    _I_closesocket = _II_closesocket;

    _P_socket = _I_socket;
    _I_socket = _II_socket;
}
VOID
AddSockEntry(
    SOCKET S
    )

/*++

Routine Description:

    Debug builds only. Records socket S in the first free slot of
    GlobalSockEntry[] and increments GlobalSocketsCount. On i386 the
    caller's stack is stored with the entry. Asserts if the table is full.

    INVALID_SOCKET, which a failed socket()/accept() returns, is ignored. It
    is not an open socket, so no matching closesocket() is expected to free
    its slot. Recording it would leak one slot per failure until the table
    filled up and asserted.

Arguments:

    S   - socket handle returned by socket() or accept()

Return Value:

    None.

--*/

{
    if (!(InternetDebugCategoryFlags & DBG_TRACE_SOCKETS)) {
        return;
    }
    if (S == INVALID_SOCKET) {
        return;
    }

    LOCK_DEBUG_SOCK();

    //
    // search for a free entry.
    //

    for (DWORD i = 0; i < MAX_SOCK_ENTRIES; i++) {
        if (GlobalSockEntry[i].Socket == 0) {

            //
            // found a free entry.
            //

            GlobalSockEntry[i].Socket = S;

            //
            // get caller stack.
            //

#if i386
            DWORD Hash = 0;

            GlobalSockEntry[i].StackTraceLength =
                pRtlCaptureStackBackTrace(
                    2,
                    MAX_STACK_TRACE,
                    GlobalSockEntry[i].StackTrace,
                    &Hash );
#else // i386
            GlobalSockEntry[i].StackTraceLength = 0;
#endif // i386

            GlobalSocketsCount++;

            DEBUG_PRINT(SOCKETS,
                        INFO,
                        ("socket count = %ld\n",
                        GlobalSocketsCount
                        ));

            DPRINTF("%d sockets\n", GlobalSocketsCount);

            UNLOCK_DEBUG_SOCK();
            return;
        }
    }

    //
    // we have reached a high handle limit, which is unusual and needs to be
    // debugged.
    //

    INET_ASSERT( FALSE );

    UNLOCK_DEBUG_SOCK();
}
VOID
RemoveSockEntry(
    SOCKET S
    )

/*++

Routine Description:

    Debug builds only. Clears the GlobalSockEntry[] slot that records socket
    S and decrements GlobalSocketsCount. A socket that is not in the table is
    silently ignored. Such sockets are expected, e.g. those created while
    DBG_TRACE_SOCKETS was off, since AddSockEntry() checks the flag on each
    call.

    S == 0 is ignored too. 0 is the free-slot marker, so searching for it
    would match an unused slot and decrement GlobalSocketsCount without
    removing any real entry.

Arguments:

    S   - socket handle about to be closed

Return Value:

    None.

--*/

{
    if (!(InternetDebugCategoryFlags & DBG_TRACE_SOCKETS)) {
        return;
    }
    if (S == 0) {
        return;
    }

    LOCK_DEBUG_SOCK();

    for (DWORD i = 0; i < MAX_SOCK_ENTRIES; i++) {
        if (GlobalSockEntry[i].Socket == S) {

            //
            // found the entry. Free it now.
            //

            memset( &GlobalSockEntry[i], 0x0, sizeof(DEBUG_SOCK_ENTRY) );
            GlobalSocketsCount--;

#ifdef IWINSOCK_DEBUG_PRINT
            DEBUG_PRINT(SOCKETS,
                        INFO,
                        ("count(%ld), RemoveSock(%lx)\n",
                        GlobalSocketsCount,
                        S
                        ));
#endif // IWINSOCK_DEBUG_PRINT

            DPRINTF("%d sockets\n", GlobalSocketsCount);

            UNLOCK_DEBUG_SOCK();
            return;
        }
    }

#ifdef IWINSOCK_DEBUG_PRINT
    DEBUG_PRINT(SOCKETS,
                INFO,
                ("count(%ld), UnknownSock(%lx)\n",
                GlobalSocketsCount,
                S
                ));
#endif // IWINSOCK_DEBUG_PRINT

    //
    // socket entry is not found: tolerated, not asserted (see above)
    //

    UNLOCK_DEBUG_SOCK();
}
//
// _II_socket - debug hook installed in _I_socket: creates the socket with the
// real socket() entry point, then records it in the tracking table
//
SOCKET
PASCAL FAR
_II_socket(
    int af,
    int type,
    int protocol
    )
{
    SOCKET created = _P_socket(af, type, protocol);

    AddSockEntry(created);
    return created;
}
//
// _II_closesocket - debug hook installed in _I_closesocket: drops the socket
// from the tracking table, then closes it with the real closesocket() and
// returns its result
//
int
PASCAL FAR
_II_closesocket(
    SOCKET s
    )
{
    RemoveSockEntry(s);
    return _P_closesocket(s);
}
//
// _II_accept - debug hook installed in _I_accept: accepts through the real
// accept() entry point, then records the returned socket in the tracking
// table
//
SOCKET
PASCAL FAR
_II_accept(
    SOCKET s,
    struct sockaddr FAR *addr,
    int FAR *addrlen
    )
{
    SOCKET accepted = _P_accept(s, addr, addrlen);

    AddSockEntry(accepted);
    return accepted;
}
//
// IWinsockCheckSockets - debug dump: prints GlobalSocketsCount and then every
// socket still recorded in GlobalSockEntry[]. It reads the table without
// taking DebugSockLock.
//
VOID
IWinsockCheckSockets(
    VOID
    )
{
    DEBUG_PRINT(SOCKETS,
                INFO,
                ("GlobalSocketsCount = %d\n",
                GlobalSocketsCount
                ));

    for (DWORD slot = 0; slot < MAX_SOCK_ENTRIES; ++slot) {
        SOCKET tracked = GlobalSockEntry[slot].Socket;

        if (tracked == 0) {
            continue;
        }

        DEBUG_PRINT(SOCKETS,
                    INFO,
                    ("Socket %#x\n",
                    tracked
                    ));
    }
}
#endif // INET_DEBUG
#if defined(__cplusplus)
}
#endif