Windows2003-3790/termsrv/tsuserex/creden.cpp

760 lines
16 KiB
C++

/*++
Copyright (c) 2001 Microsoft Corporation
Module Name:
creden.cpp
Abstract:
This module abstracts user credentials for the multiple credential support.
Author:
Rashmi Patankar (RashmiP) 10-Aug-2001
Revision History:
--*/
#include "stdafx.h"
#include <basetyps.h>
#include <des.h>
typedef long HRESULT;
#include "creden.h"
#include <ntsecapi.h>
typedef NTSTATUS SECURITY_STATUS;
UCHAR g_seed = 0 ;
#define BAIL_ON_FAILURE(hr) \
if (FAILED(hr)) \
{ \
goto error; \
}
//
// This routine allocates and stores the password in the
// passed in pointer. The assumption here is that pszString
// is valid, it can be an empty string but not NULL.
// Note that this code cannot be used as is on Win2k and below
// as they do not support the newer functions.
//
HRESULT
EncryptString(
LPWSTR pszString,
LPWSTR *ppszSafeString,
PDWORD pdwLen
)
{
HRESULT hr = S_OK;
DWORD dwLenStr = 0;
DWORD dwPwdLen = 0;
LPWSTR pszTempStr = NULL;
NTSTATUS errStatus = S_OK;
FNRTLINITUNICODESTRING pRtlInitUnicodeString = NULL;
FRTLRUNENCODEUNICODESTRING pRtlRunEncodeUnicodeString = NULL;
FRTLENCRYPTMEMORY pRtlEncryptMemory = NULL;
BOOLEAN GlobalUseScrambling = FALSE;
if (!pszString || !ppszSafeString)
{
return(E_FAIL);
}
*ppszSafeString = NULL;
*pdwLen = 0;
//
// If the string is valid, then we need to get the length
// and initialize the unicode string.
//
UNICODE_STRING Password;
//
// Determine the length of buffer taking padding into account.
//
dwLenStr = wcslen(pszString);
dwPwdLen = (dwLenStr + 1) * sizeof(WCHAR) + (DES_BLOCKLEN -1);
pszTempStr = (LPWSTR) AllocADsMem(dwPwdLen);
if (!pszTempStr)
{
hr = E_OUTOFMEMORY;
goto error;
}
wcscpy(pszTempStr, pszString);
if(g_ScramblingLibraryHandle)
{
pRtlInitUnicodeString = (FNRTLINITUNICODESTRING) GetProcAddress( g_ScramblingLibraryHandle, "RtlInitUnicodeString" );
}
if(!pRtlInitUnicodeString)
{
hr = E_FAIL;
goto error;
}
(*pRtlInitUnicodeString)(&Password, pszTempStr);
USHORT usExtra = 0;
if (usExtra = (Password.MaximumLength % DES_BLOCKLEN))
{
Password.MaximumLength += (DES_BLOCKLEN - usExtra);
}
*pdwLen = Password.MaximumLength;
if (g_AdvApi32LibraryHandle || g_ScramblingLibraryHandle)
{
GlobalUseScrambling = FALSE;
if (g_AdvApi32LibraryHandle)
{
//
// Try to get the advapi32.dll RtlEncryptMemory/RtlDecryptMemory functions,
// Note that RtlEncryptMemory and RtlDecryptMemory are really named
// SystemFunction040/041, hence the macros.
//
pRtlEncryptMemory = (FRTLENCRYPTMEMORY) GetProcAddress( g_AdvApi32LibraryHandle, (LPCSTR) 619 );
if (pRtlEncryptMemory)
{
// We want to use scrambling
GlobalUseScrambling = TRUE;
// Using strong scrambling
errStatus = (*pRtlEncryptMemory)( Password.Buffer,
Password.MaximumLength,
0
);
if (errStatus)
{
if(pszTempStr)
{
FreeADsMem(pszTempStr);
pszTempStr = NULL;
}
hr = HRESULT_FROM_NT(errStatus);
goto error;
}
}
else if (g_ScramblingLibraryHandle)
{
//
// Clean up so we can try falling back to the run-encode scrambling functions
// (we keep the AdvApi32LibraryHandle around since we'll probably need it
// later anyway)
//
pRtlRunEncodeUnicodeString = (FRTLRUNENCODEUNICODESTRING) GetProcAddress( g_ScramblingLibraryHandle, "RtlRunEncodeUnicodeString" );
if(_tcslen(Password.Buffer) && pRtlRunEncodeUnicodeString)
{
// encrypt password in place
(*pRtlRunEncodeUnicodeString)( &g_seed, &Password );
}
else
{
hr = E_FAIL;
goto error;
}
}
else
{
hr = E_FAIL;
goto error;
}
}
else
{
hr = E_FAIL;
goto error;
}
}
else
{
hr = E_FAIL;
goto error;
}
*ppszSafeString = pszTempStr;
error:
if (FAILED(hr) && pszTempStr)
{
FreeADsMem(pszTempStr);
pszTempStr = NULL;
}
return(hr);
}
HRESULT
DecryptString(
LPWSTR pszEncodedString,
LPWSTR *ppszString,
DWORD dwLen
)
{
HRESULT hr = E_FAIL;
LPWSTR pszTempStr = NULL;
NTSTATUS errStatus;
BOOLEAN GlobalUseScrambling = FALSE;
FNRTLINITUNICODESTRING pRtlInitUnicodeString = NULL;
FRTLRUNDECODEUNICODESTRING pRtlRunDecodeUnicodeString = NULL;
FRTLDECRYPTMEMORY pRtlDecryptMemory = NULL;
UNICODE_STRING UnicodePassword;
if (!dwLen || !ppszString)
{
return(E_FAIL);
}
*ppszString = NULL;
if (dwLen)
{
pszTempStr = (LPWSTR) AllocADsMem(dwLen);
if (!pszTempStr)
{
hr = E_OUTOFMEMORY;
goto error;
}
memcpy(pszTempStr, pszEncodedString, dwLen);
if (g_AdvApi32LibraryHandle || g_ScramblingLibraryHandle)
{
hr = S_OK;
GlobalUseScrambling = FALSE;
if (g_AdvApi32LibraryHandle)
{
//
// Try to get the advapi32.dll RtlEncryptMemory/RtlDecryptMemory functions,
// along with ntdll's RtlInitUnicodeString. Note that RtlEncryptMemory
// and RtlDecryptMemory are really named SystemFunction040/041, hence
// the macros.
//
pRtlDecryptMemory = (FRTLDECRYPTMEMORY) GetProcAddress( g_AdvApi32LibraryHandle, (LPCSTR) 620 );
if (pRtlDecryptMemory)
{
//
// We want to use scrambling
//
GlobalUseScrambling = TRUE;
// Using strong scrambling
errStatus = (*pRtlDecryptMemory)( pszTempStr,
dwLen,
0
);
if (errStatus)
{
if (NULL != pszTempStr)
{
FreeADsStr(pszTempStr);
pszTempStr = NULL;
}
hr = HRESULT_FROM_NT(errStatus);
goto error;
}
}
else if(g_ScramblingLibraryHandle)
{
//
// Clean up so we can try falling back to the run-encode scrambling functions
// (we keep the AdvApi32LibraryHandle around since we'll probably need it
// later anyway)
//
pRtlRunDecodeUnicodeString = (FRTLRUNDECODEUNICODESTRING) GetProcAddress( g_ScramblingLibraryHandle, "RtlRunDecodeUnicodeString" );
pRtlInitUnicodeString = (FNRTLINITUNICODESTRING) GetProcAddress( g_ScramblingLibraryHandle, "RtlInitUnicodeString" );
if(_tcslen(pszTempStr) && pRtlRunDecodeUnicodeString && pRtlInitUnicodeString)
{
(*pRtlInitUnicodeString)( &UnicodePassword, pszTempStr );
// encrypt password in place
(*pRtlRunDecodeUnicodeString)(g_seed, &UnicodePassword);
}
else
{
hr = E_FAIL;
goto error;
}
}
else
{
hr = E_FAIL;
goto error;
}
}
else
{
hr = E_FAIL;
goto error;
}
}
else
{
hr = E_FAIL;
goto error;
}
*ppszString = pszTempStr;
}
error:
if (FAILED(hr) && (NULL != pszTempStr))
{
FreeADsStr(pszTempStr);
pszTempStr = NULL;
}
return(hr);
}
//
// Static member of the class
//
CCredentials::CCredentials():
_lpszUserName(NULL),
_lpszPassword(NULL),
_dwAuthFlags(0),
_dwPasswordLen(0)
{
}
CCredentials::CCredentials(
LPWSTR lpszUserName,
LPWSTR lpszPassword,
DWORD dwAuthFlags
):
_lpszUserName(NULL),
_lpszPassword(NULL),
_dwAuthFlags(0),
_dwPasswordLen(0)
{
//
// AjayR 10-04-99 we need a way to bail if the
// alloc's fail. Since it is in the constructor this is
// not very easy to do.
//
if (lpszUserName)
{
_lpszUserName = AllocADsStr(lpszUserName);
}
else
{
_lpszUserName = NULL;
}
if (lpszPassword)
{
//
// The call can fail but we cannot recover from this.
//
EncryptString(
lpszPassword,
&_lpszPassword,
&_dwPasswordLen
);
}
else
{
_lpszPassword = NULL;
}
_dwAuthFlags = dwAuthFlags;
}
CCredentials::~CCredentials()
{
if (_lpszUserName)
{
FreeADsStr(_lpszUserName);
}
if (_lpszPassword)
{
FreeADsStr(_lpszPassword);
}
}
HRESULT
CCredentials::GetUserName(
LPWSTR *lppszUserName
)
{
if (!lppszUserName)
{
return(E_FAIL);
}
if (!_lpszUserName)
{
*lppszUserName = NULL;
}
else
{
*lppszUserName = AllocADsStr(_lpszUserName);
if (!*lppszUserName)
{
return(E_OUTOFMEMORY);
}
}
return(S_OK);
}
HRESULT
CCredentials::GetPassword(
LPWSTR * lppszPassword
)
{
if (!lppszPassword)
{
return(E_FAIL);
}
if (!_lpszPassword)
{
*lppszPassword = NULL;
}
else
{
return( DecryptString( _lpszPassword,
lppszPassword,
_dwPasswordLen
)
);
}
return(S_OK);
}
HRESULT
CCredentials::SetUserName(
LPWSTR lpszUserName
)
{
if (_lpszUserName)
{
FreeADsStr(_lpszUserName);
}
if (!lpszUserName)
{
_lpszUserName = NULL;
return(S_OK);
}
_lpszUserName = AllocADsStr(
lpszUserName
);
if(!_lpszUserName)
{
return(E_FAIL);
}
return(S_OK);
}
HRESULT
CCredentials::SetPassword(
LPWSTR lpszPassword
)
{
if (_lpszPassword)
{
FreeADsStr(_lpszPassword);
}
if (!lpszPassword)
{
_lpszPassword = NULL;
return(S_OK);
}
return( EncryptString( lpszPassword,
&_lpszPassword,
&_dwPasswordLen
)
);
}
CCredentials::CCredentials(
const CCredentials& Credentials
)
{
HRESULT hr = S_OK;
LPWSTR pszTmpPwd = NULL;
_lpszUserName = NULL;
_lpszPassword = NULL;
_lpszUserName = AllocADsStr(
Credentials._lpszUserName
);
if (Credentials._lpszPassword)
{
hr = DecryptString(
Credentials._lpszPassword,
&pszTmpPwd,
Credentials._dwPasswordLen
);
}
if (SUCCEEDED(hr) && pszTmpPwd)
{
hr = EncryptString(
pszTmpPwd,
&_lpszPassword,
&_dwPasswordLen
);
}
else
{
pszTmpPwd = NULL;
}
if (pszTmpPwd)
{
FreeADsStr(pszTmpPwd);
}
_dwAuthFlags = Credentials._dwAuthFlags;
}
void
CCredentials::operator=(
const CCredentials& other
)
{
HRESULT hr = S_OK;
LPWSTR pszTmpPwd = NULL;
if ( &other == this)
{
return;
}
if (_lpszUserName)
{
FreeADsStr(_lpszUserName);
}
if (_lpszPassword)
{
FreeADsStr(_lpszPassword);
}
_lpszUserName = AllocADsStr(
other._lpszUserName
);
if (other._lpszPassword)
{
hr = DecryptString(
other._lpszPassword,
&pszTmpPwd,
other._dwPasswordLen
);
}
if (SUCCEEDED(hr) && pszTmpPwd)
{
hr = EncryptString(
pszTmpPwd,
&_lpszPassword,
&_dwPasswordLen
);
}
else
{
pszTmpPwd = NULL;
}
if (pszTmpPwd)
{
FreeADsStr(pszTmpPwd);
}
_dwAuthFlags = other._dwAuthFlags;
return;
}
BOOL
operator==(
CCredentials& x,
CCredentials& y
)
{
BOOL bEqualUser = FALSE;
BOOL bEqualPassword = FALSE;
BOOL bEqualFlags = FALSE;
LPWSTR lpszXPassword = NULL;
LPWSTR lpszYPassword = NULL;
BOOL bReturnCode = FALSE;
HRESULT hr = S_OK;
if (x._lpszUserName && y._lpszUserName)
{
bEqualUser = !(wcscmp(x._lpszUserName, y._lpszUserName));
}
else if (!x._lpszUserName && !y._lpszUserName)
{
bEqualUser = TRUE;
}
hr = x.GetPassword(&lpszXPassword);
if (FAILED(hr))
{
goto error;
}
hr = y.GetPassword(&lpszYPassword);
if (FAILED(hr))
{
goto error;
}
if ((lpszXPassword && lpszYPassword))
{
bEqualPassword = !(wcscmp(lpszXPassword, lpszYPassword));
}
else if (!lpszXPassword && !lpszYPassword)
{
bEqualPassword = TRUE;
}
if (x._dwAuthFlags == y._dwAuthFlags)
{
bEqualFlags = TRUE;
}
if (bEqualUser && bEqualPassword && bEqualFlags)
{
bReturnCode = TRUE;
}
error:
if (lpszXPassword)
{
FreeADsStr(lpszXPassword);
}
if (lpszYPassword)
{
FreeADsStr(lpszYPassword);
}
return(bReturnCode);
}
BOOL
CCredentials::IsNullCredentials(
)
{
// The function will return true even if the flags are set
// this is because we want to try and get the default credentials
// even if the flags were set
if (!_lpszUserName && !_lpszPassword)
{
return(TRUE);
}
else
{
return(FALSE);
}
}
DWORD
CCredentials::GetAuthFlags()
{
return(_dwAuthFlags);
}
void
CCredentials::SetAuthFlags(
DWORD dwAuthFlags
)
{
_dwAuthFlags = dwAuthFlags;
}