2020-09-30 17:12:32 +02:00

520 lines
11 KiB
C++

/*
* clsfact.cpp - IClassFactory implementation.
*/
/* Headers
*/
#include "project.hpp"
#pragma hdrstop
#include "clsfact.h"
#include "ftps.hpp"
#include "inetcpl.h"
#include "inetps.hpp"
/* Types
*/
// callback function used by ClassFactory::ClassFactory()
//
// Constructs one object of a supported class.  The argument is the
// destruction callback handed on to the new object; the result is the new
// object as a PIUnknown.  ClassFactory::CreateInstance() treats a NULL result
// as an out-of-memory failure.
typedef PIUnknown (*NEWOBJECTPROC)(OBJECTDESTROYEDPROC);
// Presumably declares the companion pointer/const typedefs (PNEWOBJECTPROC is
// used by GetClassConstructor()); the macro is defined outside this file.
DECLARE_STANDARD_TYPES(NEWOBJECTPROC);
// description of class supported by DllGetClassObject()
//
// One entry of the s_cclscnstr[] lookup table: pairs a class ID with the
// callback that constructs an object of that class.
typedef struct classconstructor
{
// class ID this entry describes; compared against the requested CLSID
PCCLSID pcclsid;
// constructor callback for objects of that class
NEWOBJECTPROC NewObject;
}
CLASSCONSTRUCTOR;
// Presumably declares the companion typedefs (CCLASSCONSTRUCTOR is used for
// the s_cclscnstr[] table); the macro is defined outside this file.
DECLARE_STANDARD_TYPES(CLASSCONSTRUCTOR);
/* Classes
*/
// object class factory
//
// Single class factory implementation shared by every class this DLL serves.
// The class an instance manufactures is determined solely by the
// NEWOBJECTPROC it was constructed with.  Reference counting is delegated to
// the RefCount base (see AddRef() / Release()).
class ClassFactory : public RefCount,
public IClassFactory
{
private:
// constructor callback invoked by CreateInstance() to build a new object
NEWOBJECTPROC m_NewObject;
public:
// NewObject is stored in m_NewObject; ObjectDestroyed is forwarded to the
// RefCount base constructor.
ClassFactory(NEWOBJECTPROC NewObject, OBJECTDESTROYEDPROC ObjectDestroyed);
~ClassFactory(void);
// IClassFactory methods
// CreateInstance() refuses aggregation (non-NULL piunkOuter).
HRESULT STDMETHODCALLTYPE CreateInstance(PIUnknown piunkOuter, REFIID riid, PVOID *ppvObject);
// LockServer() adjusts the DLL reference count via DLLAddRef()/DLLRelease().
HRESULT STDMETHODCALLTYPE LockServer(BOOL bLock);
// IUnknown methods
ULONG STDMETHODCALLTYPE AddRef(void);
ULONG STDMETHODCALLTYPE Release(void);
// QueryInterface() answers IID_IClassFactory and IID_IUnknown only.
HRESULT STDMETHODCALLTYPE QueryInterface(REFIID riid, PVOID *ppvObj);
// friends
#ifdef DEBUG
// debug-only validator; needs access to the private m_NewObject member
friend BOOL IsValidPCClassFactory(const ClassFactory *pcurlcf);
#endif
};
// Presumably declares the companion typedefs (PClassFactory, PCClassFactory
// and CClassFactory are used below); the macro is defined outside this file.
DECLARE_STANDARD_TYPES(ClassFactory);
/* Module Prototypes
*/
// Constructor callbacks, one per supported class.  They are declared here
// because the s_cclscnstr[] table below takes their addresses before their
// definitions appear.
PRIVATE_CODE PIUnknown NewInternetShortcut(OBJECTDESTROYEDPROC ObjectDestroyed);
PRIVATE_CODE PIUnknown NewMIMEHook(OBJECTDESTROYEDPROC ObjectDestroyed);
PRIVATE_CODE PIUnknown NewInternet(OBJECTDESTROYEDPROC ObjectDestroyed);
/* Module Constants
*/
#pragma data_seg(DATA_SEG_READ_ONLY)
// Table of the classes DllGetClassObject() can serve: each entry maps a class
// ID to the callback that constructs an object of that class.
// GetClassConstructor() scans this table linearly.  To support a new class,
// add its CLSID / constructor pair here.
PRIVATE_DATA CCLASSCONSTRUCTOR s_cclscnstr[] =
{
{ &CLSID_InternetShortcut, &NewInternetShortcut },
{ &CLSID_MIMEFileTypesPropSheetHook, &NewMIMEHook },
{ &CLSID_Internet, &NewInternet },
};
#pragma data_seg()
/* Module Variables
*/
#pragma data_seg(DATA_SEG_PER_INSTANCE)
// DLL reference count == number of class factories +
// number of URLs +
// LockServer() count
//
// In this file the count is raised by DllGetClassObject() (one per class
// factory), ClassFactory::CreateInstance() (one per created object) and
// ClassFactory::LockServer(TRUE); it is lowered by DLLObjectDestroyed() and
// ClassFactory::LockServer(FALSE).  DllCanUnloadNow() refuses to unload while
// it is non-zero.  GetDLLRefCountPtr() also exposes its address to other
// modules.
// NOTE(review): updates are plain ++/-- with no interlocked operations --
// confirm the threading model this DLL is registered for.
PRIVATE_DATA ULONG s_ulcDLLRef = 0;
#pragma data_seg()
/***** Private Functions *****/
// GetClassConstructor()
//
// Looks up the constructor callback registered for a class ID in the
// s_cclscnstr[] table.
//
// Arguments:  rclsid     - class ID to look up
//             pNewObject - receives the matching constructor callback, or
//                          NULL if the class is not in the table
//
// Returns:    S_OK                      - class found, *pNewObject filled in
//             CLASS_E_CLASSNOTAVAILABLE - class not served by this DLL
PRIVATE_CODE HRESULT GetClassConstructor(REFCLSID rclsid,
                                         PNEWOBJECTPROC pNewObject)
{
    HRESULT hr = CLASS_E_CLASSNOTAVAILABLE;
    UINT u;

    ASSERT(IsValidREFCLSID(rclsid));
    ASSERT(IS_VALID_WRITE_PTR(pNewObject, NEWOBJECTPROC));

    *pNewObject = NULL;

    for (u = 0; u < ARRAY_ELEMENTS(s_cclscnstr); u++)
    {
        if (rclsid == *(s_cclscnstr[u].pcclsid))
        {
            *pNewObject = s_cclscnstr[u].NewObject;
            hr = S_OK;

            // Each class ID appears once in the table, so there is nothing
            // left to find; stop scanning at the first match.
            break;
        }
    }

    // Post-condition: success always carries a callback, failure never does.
    ASSERT((hr == S_OK &&
            IS_VALID_CODE_PTR(*pNewObject, NEWOBJECTPROC)) ||
           (hr == CLASS_E_CLASSNOTAVAILABLE &&
            ! *pNewObject));

    return(hr);
}
// Destruction callback passed (as &DLLObjectDestroyed) to every class factory
// created by DllGetClassObject() and every object created by
// ClassFactory::CreateInstance().  Drops the DLL reference that was taken
// with DLLAddRef() when that factory or object was created.
PRIVATE_CODE void STDMETHODCALLTYPE DLLObjectDestroyed(void)
{
TRACE_OUT(("DLLObjectDestroyed(): Object destroyed."));
DLLRelease();
}
// NewInternetShortcut()
//
// Constructor callback for CLSID_InternetShortcut (see s_cclscnstr[]).
//
// Arguments:  ObjectDestroyed - destruction callback handed to the new
//                               object; may be NULL
//
// Returns:    The new object as a PIUnknown.  ClassFactory::CreateInstance()
//             treats a NULL result as an out-of-memory failure.
PRIVATE_CODE PIUnknown NewInternetShortcut(OBJECTDESTROYEDPROC ObjectDestroyed)
{
    ASSERT(! ObjectDestroyed ||
           IS_VALID_CODE_PTR(ObjectDestroyed, OBJECTDESTROYEDPROC));

    TRACE_OUT(("NewInternetShortcut(): Creating a new InternetShortcut."));

    // Written as a standard new-expression; the previous
    // "new(InternetShortcut(...))" spelling is not valid standard C++.
    // The intermediate cast names the interface whose IUnknown is returned --
    // presumably required because InternetShortcut implements several
    // interfaces; confirm against its declaration.
    return((PIUnknown)(PIUniformResourceLocator)new InternetShortcut(ObjectDestroyed));
}
// NewMIMEHook()
//
// Constructor callback for CLSID_MIMEFileTypesPropSheetHook (see
// s_cclscnstr[]).
//
// Arguments:  ObjectDestroyed - destruction callback handed to the new
//                               object; may be NULL
//
// Returns:    The new object as a PIUnknown.  ClassFactory::CreateInstance()
//             treats a NULL result as an out-of-memory failure.
PRIVATE_CODE PIUnknown NewMIMEHook(OBJECTDESTROYEDPROC ObjectDestroyed)
{
    ASSERT(! ObjectDestroyed ||
           IS_VALID_CODE_PTR(ObjectDestroyed, OBJECTDESTROYEDPROC));

    TRACE_OUT(("NewMIMEHook(): Creating a new MIMEHook."));

    // Written as a standard new-expression; the previous
    // "new(MIMEHook(...))" spelling is not valid standard C++.
    // The intermediate cast names the interface whose IUnknown is returned --
    // presumably required because MIMEHook implements several interfaces;
    // confirm against its declaration.
    return((PIUnknown)(PIShellPropSheetExt)new MIMEHook(ObjectDestroyed));
}
// NewInternet()
//
// Constructor callback for CLSID_Internet (see s_cclscnstr[]).
//
// Arguments:  ObjectDestroyed - destruction callback handed to the new
//                               object; may be NULL
//
// Returns:    The new object as a PIUnknown.  ClassFactory::CreateInstance()
//             treats a NULL result as an out-of-memory failure.
PRIVATE_CODE PIUnknown NewInternet(OBJECTDESTROYEDPROC ObjectDestroyed)
{
    ASSERT(! ObjectDestroyed ||
           IS_VALID_CODE_PTR(ObjectDestroyed, OBJECTDESTROYEDPROC));

    TRACE_OUT(("NewInternet(): Creating a new Internet."));

    // Written as a standard new-expression; the previous
    // "new(Internet(...))" spelling is not valid standard C++.
    // The intermediate cast names the interface whose IUnknown is returned --
    // presumably required because Internet implements several interfaces;
    // confirm against its declaration.
    return((PIUnknown)(PIShellPropSheetExt)new Internet(ObjectDestroyed));
}
#ifdef DEBUG
// IsValidPCClassFactory()
//
// Debug-only structural check of a ClassFactory.  The checks run in order and
// stop at the first failure, so pccf is only dereferenced once it has been
// found readable.
//
// Returns:    TRUE if every check passes, FALSE otherwise.
PRIVATE_CODE BOOL IsValidPCClassFactory(PCClassFactory pccf)
{
    // The object itself must be readable before any member is touched.
    if (! (IS_VALID_READ_PTR(pccf, CClassFactory)))
        return(FALSE);

    // The constructor callback must point at code.
    if (! (IS_VALID_CODE_PTR(pccf->m_NewObject, NEWOBJECTPROC)))
        return(FALSE);

    // The RefCount base must be valid.
    if (! (IS_VALID_STRUCT_PTR((PCRefCount)pccf, CRefCount)))
        return(FALSE);

    // The IClassFactory base must be a valid interface.
    if (! (IS_VALID_INTERFACE_PTR((PCIClassFactory)pccf, IClassFactory)))
        return(FALSE);

    return(TRUE);
}
#endif
/****** Public Functions *****/
// DLLAddRef()
//
// Adds one reference to the DLL reference count (s_ulcDLLRef).
//
// Returns:    The DLL reference count after the increment.
PUBLIC_CODE ULONG DLLAddRef(void)
{
    // Debug check only: the increment below must not wrap.
    ASSERT(s_ulcDLLRef < ULONG_MAX);

    s_ulcDLLRef++;

    const ULONG ulcNewCount = s_ulcDLLRef;

    TRACE_OUT(("DLLAddRef(): DLL reference count is now %lu.",
               ulcNewCount));

    return(ulcNewCount);
}
// DLLRelease()
//
// Removes one reference from the DLL reference count (s_ulcDLLRef).  A
// release with the count already at zero leaves it at zero.
//
// Returns:    The DLL reference count after the call.
PUBLIC_CODE ULONG DLLRelease(void)
{
    // Never decrement below zero.  EVAL() presumably also reports the
    // unbalanced release in debug builds -- defined outside this file.
    if (EVAL(s_ulcDLLRef > 0))
    {
        --s_ulcDLLRef;
    }

    const ULONG ulcRemaining = s_ulcDLLRef;

    TRACE_OUT(("DLLRelease(): DLL reference count is now %lu.",
               ulcRemaining));

    return(ulcRemaining);
}
// GetDLLRefCountPtr()
//
// Returns:    The address of the DLL reference count (s_ulcDLLRef), giving
//             the caller direct access to the counter.
PUBLIC_CODE PULONG GetDLLRefCountPtr(void)
{
    PULONG pulcDLLRef = &s_ulcDLLRef;

    return(pulcDLLRef);
}
/*** Methods ***/
// ClassFactory::ClassFactory()
//
// Arguments:  NewObject       - constructor callback used by CreateInstance()
//             ObjectDestroyed - destruction callback forwarded to the
//                               RefCount base
ClassFactory::ClassFactory(NEWOBJECTPROC NewObject,
                           OBJECTDESTROYEDPROC ObjectDestroyed)
    : RefCount(ObjectDestroyed),
      m_NewObject(NewObject)
{
    DebugEntry(ClassFactory::ClassFactory);

    ASSERT(IS_VALID_CODE_PTR(NewObject, NEWOBJECTPROC));

    // The object can only be validated once it is fully constructed.
    ASSERT(IS_VALID_STRUCT_PTR(this, CClassFactory));

    DebugExitVOID(ClassFactory::ClassFactory);
}
// ClassFactory::~ClassFactory()
//
// Validates the object one last time, then clears the constructor callback.
ClassFactory::~ClassFactory(void)
{
    DebugEntry(ClassFactory::~ClassFactory);

    // Validate on the way in; the object is no longer valid on the way out.
    ASSERT(IS_VALID_STRUCT_PTR(this, CClassFactory));

    m_NewObject = NULL;

    DebugExitVOID(ClassFactory::~ClassFactory);
}
// ClassFactory::AddRef()
//
// IUnknown::AddRef() implementation; delegates to the RefCount base.
//
// Returns:    The value returned by RefCount::AddRef().
ULONG STDMETHODCALLTYPE ClassFactory::AddRef(void)
{
    DebugEntry(ClassFactory::AddRef);

    ASSERT(IS_VALID_STRUCT_PTR(this, CClassFactory));

    const ULONG ulcNewCount = RefCount::AddRef();

    ASSERT(IS_VALID_STRUCT_PTR(this, CClassFactory));

    DebugExitULONG(ClassFactory::AddRef, ulcNewCount);

    return(ulcNewCount);
}
// ClassFactory::Release()
//
// IUnknown::Release() implementation; delegates to the RefCount base.
//
// Returns:    The value returned by RefCount::Release().
ULONG STDMETHODCALLTYPE ClassFactory::Release(void)
{
    DebugEntry(ClassFactory::Release);

    ASSERT(IS_VALID_STRUCT_PTR(this, CClassFactory));

    const ULONG ulcRemaining = RefCount::Release();

    // "this" is deliberately not validated after the call -- presumably
    // because RefCount::Release() may have destroyed the object.

    DebugExitULONG(ClassFactory::Release, ulcRemaining);

    return(ulcRemaining);
}
// ClassFactory::QueryInterface()
//
// IUnknown::QueryInterface() implementation.  Only IID_IClassFactory and
// IID_IUnknown are supported.
//
// Arguments:  riid      - requested interface ID
//             ppvObject - receives the interface pointer, or NULL on failure
//
// Returns:    S_OK          - *ppvObject filled in and AddRef()ed
//             E_NOINTERFACE - interface not supported
HRESULT STDMETHODCALLTYPE ClassFactory::QueryInterface(REFIID riid,
                                                       PVOID *ppvObject)
{
    // Assume failure; each supported-interface branch overrides this.
    HRESULT hr = E_NOINTERFACE;

    DebugEntry(ClassFactory::QueryInterface);

    ASSERT(IS_VALID_STRUCT_PTR(this, CClassFactory));
    ASSERT(IsValidREFIID(riid));
    ASSERT(IS_VALID_WRITE_PTR(ppvObject, PVOID));

    if (riid == IID_IClassFactory)
    {
        *ppvObject = (PIClassFactory)this;
        ASSERT(IS_VALID_INTERFACE_PTR((PIClassFactory)*ppvObject, IClassFactory));
        TRACE_OUT(("ClassFactory::QueryInterface(): Returning IClassFactory."));

        // The returned pointer carries its own reference.
        AddRef();
        hr = S_OK;
    }
    else if (riid == IID_IUnknown)
    {
        *ppvObject = (PIUnknown)this;
        ASSERT(IS_VALID_INTERFACE_PTR((PIUnknown)*ppvObject, IUnknown));
        TRACE_OUT(("ClassFactory::QueryInterface(): Returning IUnknown."));

        // The returned pointer carries its own reference.
        AddRef();
        hr = S_OK;
    }
    else
    {
        *ppvObject = NULL;
        TRACE_OUT(("ClassFactory::QueryInterface(): Called on unknown interface."));
    }

    ASSERT(IS_VALID_STRUCT_PTR(this, CClassFactory));
    ASSERT(FAILED(hr) ||
           IS_VALID_INTERFACE_PTR(*ppvObject, INTERFACE));

    DebugExitHRESULT(ClassFactory::QueryInterface, hr);

    return(hr);
}
// ClassFactory::CreateInstance()
//
// IClassFactory::CreateInstance() implementation.  Builds a new object with
// the constructor callback given at construction time and queries it for the
// requested interface.
//
// Arguments:  piunkOuter - controlling unknown; must be NULL (aggregation is
//                          not supported)
//             riid       - interface requested on the new object
//             ppvObject  - receives the interface pointer; set to NULL first
//
// Returns:    CLASS_E_NOAGGREGATION - piunkOuter was not NULL
//             E_OUTOFMEMORY         - the constructor callback returned NULL
//             otherwise the result of the new object's QueryInterface()
HRESULT STDMETHODCALLTYPE ClassFactory::CreateInstance(PIUnknown piunkOuter,
                                                       REFIID riid,
                                                       PVOID *ppvObject)
{
    HRESULT hr;

    DebugEntry(ClassFactory::CreateInstance);

    ASSERT(IS_VALID_STRUCT_PTR(this, CClassFactory));
    ASSERT(! piunkOuter ||
           IS_VALID_INTERFACE_PTR(piunkOuter, IUnknown));
    ASSERT(IsValidREFIID(riid));
    ASSERT(IS_VALID_WRITE_PTR(ppvObject, PVOID));

    *ppvObject = NULL;

    if (piunkOuter)
    {
        hr = CLASS_E_NOAGGREGATION;
        WARNING_OUT(("ClassFactory::CreateInstance(): Aggregation not supported."));
    }
    else
    {
        PIUnknown piunkNew = (*m_NewObject)(&DLLObjectDestroyed);

        if (! piunkNew)
        {
            hr = E_OUTOFMEMORY;
        }
        else
        {
            // Matches the DLLRelease() that DLLObjectDestroyed() performs
            // when the new object is destroyed.
            DLLAddRef();

            hr = piunkNew->QueryInterface(riid, ppvObject);

            // N.b., the Release() method will destroy the object if the
            // QueryInterface() method failed.
            piunkNew->Release();
        }
    }

    ASSERT(IS_VALID_STRUCT_PTR(this, CClassFactory));
    ASSERT(FAILED(hr) ||
           IS_VALID_INTERFACE_PTR(*ppvObject, INTERFACE));

    DebugExitHRESULT(ClassFactory::CreateInstance, hr);

    return(hr);
}
// ClassFactory::LockServer()
//
// IClassFactory::LockServer() implementation.  A lock is one extra DLL
// reference; an unlock drops one.
//
// Arguments:  bLock - non-zero to lock, zero to unlock
//
// Returns:    S_OK always.
HRESULT STDMETHODCALLTYPE ClassFactory::LockServer(BOOL bLock)
{
    DebugEntry(ClassFactory::LockServer);

    ASSERT(IS_VALID_STRUCT_PTR(this, CClassFactory));

    // bLock is not validated: any non-zero value means "lock".
    if (! bLock)
    {
        DLLRelease();
    }
    else
    {
        DLLAddRef();
    }

    const HRESULT hr = S_OK;

    ASSERT(IS_VALID_STRUCT_PTR(this, CClassFactory));

    DebugExitHRESULT(ClassFactory::LockServer, hr);

    return(hr);
}
/***** Exported Functions ****/
// DllGetClassObject()
//
// Exported COM entry point.  Creates a class factory for one of the classes
// listed in s_cclscnstr[].
//
// Arguments:  rclsid    - class ID a factory is requested for
//             riid      - interface requested on the factory; only
//                         IID_IUnknown and IID_IClassFactory are accepted
//             ppvObject - receives the interface pointer; set to NULL first
//
// Returns:    S_OK                      - factory created, *ppvObject set
//             CLASS_E_CLASSNOTAVAILABLE - class not served by this DLL
//             E_NOINTERFACE             - riid not supported by the factory
//             E_OUTOFMEMORY             - factory allocation returned NULL
STDAPI DllGetClassObject(REFCLSID rclsid, REFIID riid, PVOID *ppvObject)
{
    HRESULT hr;
    NEWOBJECTPROC NewObject;

    DebugEntry(DllGetClassObject);

    ASSERT(IsValidREFCLSID(rclsid));
    ASSERT(IsValidREFIID(riid));
    ASSERT(IS_VALID_WRITE_PTR(ppvObject, PVOID));

    *ppvObject = NULL;

    hr = GetClassConstructor(rclsid, &NewObject);

    if (hr == S_OK)
    {
        if (riid == IID_IUnknown ||
            riid == IID_IClassFactory)
        {
            PClassFactory pcf;

            // Written as a standard new-expression; the previous
            // "new(ClassFactory(...))" spelling is not valid standard C++.
            pcf = new ClassFactory(NewObject, &DLLObjectDestroyed);

            if (pcf)
            {
                if (riid == IID_IClassFactory)
                {
                    *ppvObject = (PIClassFactory)pcf;
                    ASSERT(IS_VALID_INTERFACE_PTR((PIClassFactory)*ppvObject, IClassFactory));
                    TRACE_OUT(("DllGetClassObject(): Returning IClassFactory."));
                }
                else
                {
                    ASSERT(riid == IID_IUnknown);
                    *ppvObject = (PIUnknown)pcf;
                    ASSERT(IS_VALID_INTERFACE_PTR((PIUnknown)*ppvObject, IUnknown));
                    TRACE_OUT(("DllGetClassObject(): Returning IUnknown."));
                }

                // Matches the DLLRelease() that DLLObjectDestroyed() performs
                // when the factory is destroyed.  hr is already S_OK here.
                DLLAddRef();

                TRACE_OUT(("DllGetClassObject(): Created a new class factory."));
            }
            else
                hr = E_OUTOFMEMORY;
        }
        else
        {
            WARNING_OUT(("DllGetClassObject(): Called on unknown interface."));
            hr = E_NOINTERFACE;
        }
    }
    else
        WARNING_OUT(("DllGetClassObject(): Called on unknown class."));

    ASSERT(FAILED(hr) ||
           IS_VALID_INTERFACE_PTR(*ppvObject, INTERFACE));

    DebugExitHRESULT(DllGetClassObject, hr);

    return(hr);
}
// DllCanUnloadNow()
//
// Exported COM entry point.  Reports whether the DLL may be unloaded.
//
// Returns:    S_FALSE if the DLL reference count is non-zero; otherwise the
//             result of InternetCPLCanUnloadNow().
STDAPI DllCanUnloadNow(void)
{
    HRESULT hr;

    DebugEntry(DllCanUnloadNow);

    if (s_ulcDLLRef > 0)
    {
        // Outstanding factories, objects or server locks keep the DLL loaded.
        hr = S_FALSE;
    }
    else
    {
        // No references from this module; InternetCPLCanUnloadNow() has the
        // final say.
        hr = InternetCPLCanUnloadNow();
    }

    TRACE_OUT(("DllCanUnloadNow(): DLL reference count is %lu.",
               s_ulcDLLRef));

    DebugExitHRESULT(DllCanUnloadNow, hr);

    return(hr);
}