#include "provider.h"
/* Winsock upcall table. Its WPUQuerySocketHandleContext, WPUCreateSocketHandle
 * and WPUCloseSocketHandle entries are used below to map layered socket
 * handles to SOCK_INFO contexts. Presumably filled in by WSPStartup (not
 * visible in this chunk) — confirm. */
WSPUPCALLTABLE MainUpCallTable;
/* NOTE(review): not referenced in this chunk; presumably the catalog entry ID
 * of this layer's protocol entry — confirm. */
DWORD gLayerCatId = 0;
/* Catalog entry ID passed to WPUCreateSocketHandle when WSPAccept creates a
 * new layered socket handle. */
DWORD gChainId = 0;
/* Reference count checked and decremented by WSPCleanup (under
 * gCriticalSection); when it reaches 0 the next provider's DLL (hProvider) is
 * freed. Presumably incremented by WSPStartup — confirm. */
DWORD gEntryCount = 0;
/* Initialized in DllMain on process attach; guards gEntryCount and hProvider
 * in WSPCleanup. */
CRITICAL_SECTION gCriticalSection;
/* NOTE(review): not referenced in this chunk. */
LPWSPDATA gWSPData = NULL;
/* Dispatch table of the next (lower) provider; each WSP* function in this
 * chunk forwards its call through this table. */
WSPPROC_TABLE NextProcTable;
/* NOTE(review): not referenced in this chunk; presumably the proc table
 * handed back to ws2_32 from WSPStartup — confirm. */
LPWSPPROC_TABLE gProcTable = NULL;
/* Protocol info for the base provider(s); WSPAddressToString forwards
 * &gBaseInfo[0] to the next provider. Presumably holds gLayerCount entries —
 * confirm against the code that allocates it. */
LPWSAPROTOCOL_INFOW gBaseInfo = NULL;
/* This DLL's module handle, recorded by DllMain on process attach. */
HINSTANCE HDllInstance = NULL;
/* Module handle of the next provider's DLL; FreeLibrary'd by WSPCleanup when
 * gEntryCount drops to 0. */
HINSTANCE hProvider = NULL;
INT gLayerCount=0; // Number of base providers we're layered over
/* NOTE(review): not referenced in this chunk; presumably a scratch buffer for
 * formatted debug messages. */
static TCHAR Msg[512];
/*
 * DLL entry point.
 *
 * DLL_PROCESS_ATTACH: records this module's handle in HDllInstance and
 * initializes gCriticalSection (which WSPCleanup uses to guard gEntryCount
 * and hProvider). It also calls InitOverlappedCS(), presumably to set up the
 * overlapped-I/O helper's own lock (defined elsewhere).
 *
 * DLL_PROCESS_DETACH: when the DLL is being unloaded by FreeLibrary
 * (lpvReserved == NULL), gCriticalSection is deleted so it does not leak.
 * When the process is terminating (lpvReserved != NULL), other threads have
 * already been stopped and may have left the lock in any state, so it is
 * intentionally left alone.
 * NOTE(review): if InitOverlappedCS() allocates a lock, it likely needs a
 * matching teardown here as well.
 *
 * Always returns TRUE.
 */
BOOL WINAPI DllMain(IN HINSTANCE hinstDll, IN DWORD dwReason, LPVOID lpvReserved)
{
    switch (dwReason)
    {
    case DLL_PROCESS_ATTACH:
        HDllInstance = hinstDll;
        InitializeCriticalSection(&gCriticalSection);
        InitOverlappedCS();
        break;

    case DLL_PROCESS_DETACH:
        /* Only tear down on a dynamic unload; see header comment. */
        if (lpvReserved == NULL)
            DeleteCriticalSection(&gCriticalSection);
        break;

    case DLL_THREAD_ATTACH:
    case DLL_THREAD_DETACH:
    default:
        break;
    }
    return TRUE;
}
/*
 * WSPAccept - accept a connection on the layered socket `s`.
 *
 * Looks up the SOCK_INFO context for `s` and forwards the accept to the next
 * provider on the underlying provider socket. On success it:
 *   - allocates a zeroed SOCK_INFO for the accepted provider socket,
 *   - wraps that context in a new layered handle via WPUCreateSocketHandle
 *     (catalog entry gChainId), and
 *   - calls DuplicateAsyncSocket with the listening provider socket, the
 *     accepted provider socket and the new layered handle.
 *
 * Returns the new layered socket, or INVALID_SOCKET with *lpErrno set.
 * If anything fails after the next provider has accepted, the accepted
 * provider socket is closed and the context is freed so neither leaks.
 */
SOCKET WSPAPI WSPAccept (
    SOCKET s,
    struct sockaddr FAR * addr,
    LPINT addrlen,
    LPCONDITIONPROC lpfnCondition,
    DWORD dwCallbackData,
    LPINT lpErrno)
{
    SOCKET NewProviderSocket;
    SOCKET NewSocket;
    SOCK_INFO *NewSocketContext;
    SOCK_INFO *SocketContext;
    INT CleanupErrno;   /* scratch slot so cleanup doesn't clobber *lpErrno */

    if (MainUpCallTable.lpWPUQuerySocketHandleContext(s, (LPDWORD) &SocketContext, lpErrno) == SOCKET_ERROR)
        return INVALID_SOCKET;

    NewProviderSocket = NextProcTable.lpWSPAccept(SocketContext->ProviderSocket, addr, addrlen,
        lpfnCondition, dwCallbackData, lpErrno);
    if (NewProviderSocket == INVALID_SOCKET)
        return INVALID_SOCKET;

    /* sizeof on the pointee: `sizeof SOCK_INFO` (no parentheses) is ill-formed. */
    NewSocketContext = (SOCK_INFO *) GlobalAlloc(GPTR, sizeof *NewSocketContext);
    if (NewSocketContext == NULL)
    {
        /* The next provider already accepted; don't leak that connection. */
        NextProcTable.lpWSPCloseSocket(NewProviderSocket, &CleanupErrno);
        *lpErrno = WSAENOBUFS;
        return INVALID_SOCKET;
    }

    NewSocketContext->ProviderSocket = NewProviderSocket;
    NewSocketContext->bClosing = FALSE;
    NewSocketContext->dwOutstandingAsync = 0;
    NewSocketContext->BytesRecv = 0;
    NewSocketContext->BytesSent = 0;

    NewSocket = MainUpCallTable.lpWPUCreateSocketHandle(gChainId, (DWORD) NewSocketContext, lpErrno);
    if (NewSocket == INVALID_SOCKET)
    {
        /* *lpErrno was set by WPUCreateSocketHandle; release what we own. */
        NextProcTable.lpWSPCloseSocket(NewProviderSocket, &CleanupErrno);
        GlobalFree(NewSocketContext);
        return INVALID_SOCKET;
    }

    DuplicateAsyncSocket(SocketContext->ProviderSocket, NewProviderSocket, NewSocket);
    {
        TCHAR buffer[128];
        wsprintf(buffer, L"Creating socket %d\n", NewSocket);
        OutputDebugString(buffer);
    }
    return NewSocket;
}
/*
 * WSPAddressToString - convert a socket address to its string form.
 *
 * Pure pass-through to the next provider. The caller's lpProtocolInfo is
 * not forwarded; the first base-provider entry (gBaseInfo) is passed
 * instead — presumably because the lower provider expects its own catalog
 * entry rather than this layer's. Returns whatever the next provider returns.
 */
int WSPAPI WSPAddressToString(
    LPSOCKADDR lpsaAddress,
    DWORD dwAddressLength,
    LPWSAPROTOCOL_INFOW lpProtocolInfo,
    LPWSTR lpszAddressString,
    LPDWORD lpdwAddressStringLength,
    LPINT lpErrno)
{
    LPWSAPROTOCOL_INFOW lpNextInfo = gBaseInfo;   /* same as &gBaseInfo[0] */

    return (*NextProcTable.lpWSPAddressToString)(
        lpsaAddress,
        dwAddressLength,
        lpNextInfo,
        lpszAddressString,
        lpdwAddressStringLength,
        lpErrno);
}
/*
 * WSPAsyncSelect - request window-message notification for socket `s`.
 *
 * The caller's (hWnd, wMsg) pair is registered with SetWorkerWindow together
 * with the provider socket and the layered socket. The next provider is then
 * asked to post WM_SOCKET to the returned worker window instead of the
 * caller's window. Presumably the worker window translates and re-posts
 * notifications to the caller (implemented elsewhere — confirm).
 *
 * Returns the next provider's result, or SOCKET_ERROR with *lpErrno set.
 */
int WSPAPI WSPAsyncSelect (
    SOCKET s,
    HWND hWnd,
    unsigned int wMsg,
    long lEvent,
    LPINT lpErrno)
{
    SOCK_INFO *SocketContext;
    HWND hWorkerWindow;

    if (MainUpCallTable.lpWPUQuerySocketHandleContext(s, (LPDWORD) &SocketContext, lpErrno) == SOCKET_ERROR)
        return SOCKET_ERROR;

    hWorkerWindow = SetWorkerWindow(SocketContext->ProviderSocket, s, hWnd, wMsg);
    if (hWorkerWindow == NULL)
    {
        /* SetWorkerWindow has no error out-parameter, so *lpErrno would
         * otherwise be left unset on this SOCKET_ERROR return.
         * NOTE(review): WSAENOBUFS assumes the failure is resource-related;
         * confirm against SetWorkerWindow's failure modes. */
        *lpErrno = WSAENOBUFS;
        return SOCKET_ERROR;
    }

    return NextProcTable.lpWSPAsyncSelect(SocketContext->ProviderSocket, hWorkerWindow, WM_SOCKET, lEvent, lpErrno);
}
/*
 * WSPBind - bind the layered socket `s` to a local address.
 *
 * Resolves `s` to its SOCK_INFO and forwards the bind to the next provider
 * on the underlying provider socket. Returns SOCKET_ERROR (with *lpErrno
 * set by the upcall) if the context lookup fails, otherwise the next
 * provider's result.
 */
int WSPAPI WSPBind(
    SOCKET s,
    const struct sockaddr FAR * name,
    int namelen,
    LPINT lpErrno)
{
    SOCK_INFO *Ctx = NULL;
    int Status = MainUpCallTable.lpWPUQuerySocketHandleContext(s, (LPDWORD) &Ctx, lpErrno);

    if (Status != SOCKET_ERROR)
        Status = NextProcTable.lpWSPBind(Ctx->ProviderSocket, name, namelen, lpErrno);

    return Status;
}
/*
 * WSPCancelBlockingCall - cancel an in-progress blocking call.
 *
 * This layer keeps no state for it; the request is handed straight to the
 * next provider and its result is returned unchanged.
 */
int WSPAPI WSPCancelBlockingCall(
    LPINT lpErrno)
{
    return (*NextProcTable.lpWSPCancelBlockingCall)(lpErrno);
}
/*
 * WSPCleanup - drop one startup reference on this provider.
 *
 * Fails with WSANOTINITIALISED if gEntryCount is already 0. Otherwise it
 * forwards the cleanup to the next provider and decrements gEntryCount. When
 * the count reaches 0, it unloads the next provider's DLL (hProvider).
 *
 * The zero check, the forward and the decrement are all done under
 * gCriticalSection. Previously the check ran outside the lock, so two
 * concurrent callers could both pass it with gEntryCount == 1. That
 * underflowed the DWORD counter and called FreeLibrary on hProvider twice.
 *
 * Returns the next provider's WSPCleanup result, or SOCKET_ERROR with
 * *lpErrno = WSANOTINITIALISED.
 */
int WSPAPI WSPCleanup (
    LPINT lpErrno
    )
{
    int Ret;

    EnterCriticalSection(&gCriticalSection);

    if (gEntryCount == 0)
    {
        LeaveCriticalSection(&gCriticalSection);
        *lpErrno = WSANOTINITIALISED;
        return SOCKET_ERROR;
    }

    Ret = NextProcTable.lpWSPCleanup(lpErrno);

    gEntryCount--;
    if (gEntryCount == 0)
    {
        /* Last reference: NextProcTable now points into a module we release. */
        FreeLibrary(hProvider);
        hProvider = NULL;
    }

    LeaveCriticalSection(&gCriticalSection);
    return Ret;
}
/*
 * WSPCloseSocket - close the layered socket `s`.
 *
 * The underlying provider socket is always closed through the next provider.
 *
 * If overlapped operations are still outstanding (dwOutstandingAsync != 0),
 * the context is flagged bClosing *before* that close. The function then
 * returns 0 without releasing the layered handle or freeing the context.
 * Presumably the completion path finishes the teardown (not visible here).
 *
 * Otherwise it calls RemoveSockInfo for the provider socket and releases the
 * layered handle via WPUCloseSocketHandle. It then logs the byte counters to
 * the debugger and frees the SOCK_INFO context.
 *
 * Returns 0 on success, SOCKET_ERROR with *lpErrno set on failure.
 */
int WSPAPI WSPCloseSocket (
    SOCKET s,
    LPINT lpErrno
    )
{
    SOCK_INFO *Ctx;
    BOOL bIoPending;

    if (MainUpCallTable.lpWPUQuerySocketHandleContext(s, (LPDWORD) &Ctx, lpErrno) == SOCKET_ERROR)
        return SOCKET_ERROR;

    bIoPending = (Ctx->dwOutstandingAsync != 0);

    /* Flag the context before the provider socket goes away, matching the
     * original ordering. */
    if (bIoPending)
        Ctx->bClosing = TRUE;

    if (NextProcTable.lpWSPCloseSocket(Ctx->ProviderSocket, lpErrno) == SOCKET_ERROR)
        return SOCKET_ERROR;

    if (bIoPending)
        return 0;   /* teardown deferred while overlapped I/O is in flight */

    RemoveSockInfo(Ctx->ProviderSocket);

    if (MainUpCallTable.lpWPUCloseSocketHandle(s, lpErrno) == SOCKET_ERROR)
        return SOCKET_ERROR;

    {
        TCHAR szTrace[128];

        wsprintf(szTrace, L"Closing socket %d Bytes Sent [%lu] Bytes Recv [%lu]\n", s,
            Ctx->BytesSent, Ctx->BytesRecv);
        OutputDebugString(szTrace);
    }

    GlobalFree(Ctx);
    return 0;
}
/*
 * WSPConnect - connect the layered socket `s` to a peer.
 *
 * Resolves `s` to its SOCK_INFO and forwards every argument, unchanged, to
 * the next provider on the underlying provider socket. Returns SOCKET_ERROR
 * if the context lookup fails, otherwise the next provider's result.
 */
int WSPAPI WSPConnect (
    SOCKET s,
    const struct sockaddr FAR * name,
    int namelen,
    LPWSABUF lpCallerData,
    LPWSABUF lpCalleeData,
    LPQOS lpSQOS,
    LPQOS lpGQOS,
    LPINT lpErrno
    )
{
    SOCK_INFO *Ctx;

    if (MainUpCallTable.lpWPUQuerySocketHandleContext(s, (LPDWORD) &Ctx, lpErrno) == SOCKET_ERROR)
        return SOCKET_ERROR;

    return NextProcTable.lpWSPConnect(Ctx->ProviderSocket,
                                      name, namelen,
                                      lpCallerData, lpCalleeData,
                                      lpSQOS, lpGQOS,
                                      lpErrno);
}
/*
 * WSPDuplicateSocket - prepare `s` for duplication into process dwProcessId.
 *
 * Resolves `s` to its SOCK_INFO and forwards the request to the next
 * provider on the underlying provider socket. lpProtocolInfo is filled in by
 * the next provider.
 * NOTE(review): the resulting protocol info therefore describes the lower
 * provider, not this layer.
 *
 * Returns SOCKET_ERROR if the context lookup fails, otherwise the next
 * provider's result.
 */
int WSPAPI WSPDuplicateSocket(
    SOCKET s,
    DWORD dwProcessId,
    LPWSAPROTOCOL_INFOW lpProtocolInfo,
    LPINT lpErrno)
{
    SOCK_INFO *Ctx = NULL;
    int Status = MainUpCallTable.lpWPUQuerySocketHandleContext(s, (LPDWORD) &Ctx, lpErrno);

    if (Status != SOCKET_ERROR)
        Status = NextProcTable.lpWSPDuplicateSocket(Ctx->ProviderSocket,
                                                    dwProcessId, lpProtocolInfo, lpErrno);

    return Status;
}
/*
 * WSPEnumNetworkEvents - report network events recorded for socket `s`.
 *
 * Resolves `s` to its SOCK_INFO and forwards the query, including the
 * optional event object to reset, to the next provider on the underlying
 * provider socket. Returns SOCKET_ERROR if the context lookup fails,
 * otherwise the next provider's result.
 */
int WSPAPI WSPEnumNetworkEvents(
    SOCKET s,
    WSAEVENT hEventObject,
    LPWSANETWORKEVENTS lpNetworkEvents,
    LPINT lpErrno)
{
    SOCK_INFO *Ctx = NULL;
    int Status = MainUpCallTable.lpWPUQuerySocketHandleContext(s, (LPDWORD) &Ctx, lpErrno);

    if (Status != SOCKET_ERROR)
        Status = NextProcTable.lpWSPEnumNetworkEvents(Ctx->ProviderSocket,
                                                      hEventObject, lpNetworkEvents, lpErrno);

    return Status;
}
int WSPAPI WSPEventSelect(
SOCKET s,
WSAEVENT hEventObject,
long lNetworkEvents,
LPINT lpErrno)
{
SOCK_INFO *SocketContext;
if (MainUpCallTable.lpWPUQuerySocketHandleContext(s, (LPDWORD) &SocketContext, lpErrno) == SOCKET_ERROR)
return SOCKET_ERROR;
return NextProcTable.lpWSPEventSelect(SocketContext->ProviderSocket, hEventObject,