diff --git a/pykd/dia/diasymbol.cpp b/pykd/dia/diasymbol.cpp index 33df9f1..c72c904 100644 --- a/pykd/dia/diasymbol.cpp +++ b/pykd/dia/diasymbol.cpp @@ -55,37 +55,45 @@ SymbolPtr DiaSymbol::fromGlobalScope( IDiaSymbol *_symbol ) ////////////////////////////////////////////////////////////////////////////////// -SymbolPtrList DiaSymbol::findChildren(ULONG symTag, const std::string &name) +DiaSymbol::SelectedChilds DiaSymbol::selectChildren( + ULONG symtag, + LPCOLESTR name, + DWORD compareFlags +) { + BOOST_ASSERT(symtag < SymTagMax); + + const bool findAll = + ( !name || !*name || (name[0] == L'*' && name[1] == L'\0') ); + DiaEnumSymbolsPtr symbols; - HRESULT hres; - const bool bFindAllNames = ( name.empty() || name == "*" ); - - if ( bFindAllNames ) - { - hres = - m_symbol->findChildren( - static_cast(symTag), - NULL, - nsNone, - &symbols); - } - else - { - hres = - m_symbol->findChildren( - static_cast(symTag), - toWStr(name), - nsRegularExpression, - &symbols); - } + HRESULT hres = + m_symbol->findChildren( + static_cast(symtag), + ( findAll ? NULL : name ), + ( findAll ? nsNone : compareFlags), + &symbols); if (S_OK != hres) throw DiaException("IDiaSymbol::findChildren", hres); - SymbolPtrList childList; + LONG count = 0; + hres = symbols->get_Count(&count); + if (S_OK != hres) + throw DiaException("IDiaEnumSymbols::get_Count", hres); + return SelectedChilds(symbols, count); +} + +////////////////////////////////////////////////////////////////////////////////// + +SymbolPtrList DiaSymbol::findChildren(ULONG symTag, const std::string &name) +{ + DiaEnumSymbolsPtr symbols = + selectChildren(symTag, toWStr(name), nsRegularExpression).first; + + SymbolPtrList childList; DiaSymbolPtr child; ULONG celt = 0; while ( SUCCEEDED(symbols->Next(1, &child, &celt)) && (celt == 1) ) @@ -115,50 +123,22 @@ ULONG DiaSymbol::getBitPosition() ULONG DiaSymbol::getChildCount(ULONG symTag) { - DiaEnumSymbolsPtr symbols; - HRESULT hres = - m_symbol->findChildren( - static_cast(symTag), - NULL, - nsfCaseSensitive | nsfUndecoratedName, - &symbols); - if (S_OK != hres) - throw DiaException("Call IDiaSymbol::findChildren", hres); - - LONG count; - hres = symbols->get_Count(&count); - if (S_OK != hres) - throw DiaException("Call IDiaEnumSymbols::get_Count", hres); - - return count; + return selectChildren(symTag).second; } //////////////////////////////////////////////////////////////////////////////// SymbolPtr DiaSymbol::getChildByIndex(ULONG symTag, ULONG _index ) { - DiaEnumSymbolsPtr symbols; - HRESULT hres = - m_symbol->findChildren( - static_cast(symTag), - NULL, - nsfCaseSensitive | nsfUndecoratedName, - &symbols); - if (S_OK != hres) - throw DiaException("Call IDiaSymbol::findChildren", hres); + SelectedChilds selected = selectChildren(symTag); - LONG count; - hres = symbols->get_Count(&count); - if (S_OK != hres) - throw DiaException("Call IDiaEnumSymbols::get_Count", hres); - - if (LONG(_index) >= count) + if (LONG(_index) >= selected.second) { throw PyException( PyExc_IndexError, "Index out of range"); } DiaSymbolPtr child; - hres = symbols->Item(_index, &child); + HRESULT hres = selected.first->Item(_index, &child); if (S_OK != hres) throw DiaException("Call IDiaEnumSymbols::Item", hres); @@ -167,33 +147,12 @@ SymbolPtr DiaSymbol::getChildByIndex(ULONG symTag, ULONG _index ) //////////////////////////////////////////////////////////////////////////////// -SymbolPtr DiaSymbol::getChildByName(const std::string &name ) +SymbolPtr DiaSymbol::getChildByName(const std::string &name) { - DiaEnumSymbolsPtr symbols; - HRESULT hres = - m_symbol->findChildren( - ::SymTagNull, - toWStr(name), - nsfCaseSensitive, - &symbols); + SelectedChilds selected = selectChildren(::SymTagNull, toWStr(name), nsCaseSensitive); - LONG count; - hres = symbols->get_Count(&count); - if (S_OK != hres) - throw DiaException("Call IDiaEnumSymbols::get_Count", hres); - - if (count > 0) - { - if (count > 1) - throw SymbolException(name + "is ambiguous"); - - DiaSymbolPtr child; - hres = symbols->Item(0, &child); - if (S_OK != hres) - throw DiaException("Call IDiaEnumSymbols::Item", hres); - - return SymbolPtr( new DiaSymbol(child, m_machineType) ); - } + if (selected.second > 0) + return getChildBySelected(selected, name); if (m_publicSymbols) { @@ -208,67 +167,50 @@ SymbolPtr DiaSymbol::getChildByName(const std::string &name ) std::string underscoreName; underscoreName += '_'; underscoreName += name; - symbols = 0; - hres = - m_symbol->findChildren( + selected = + selectChildren( ::SymTagNull, toWStr(underscoreName), - nsfCaseSensitive | nsfUndecoratedName, - &symbols); + nsfCaseSensitive | nsfUndecoratedName); - hres = symbols->get_Count(&count); - if (S_OK != hres) - throw DiaException("Call IDiaEnumSymbols::get_Count", hres); - - if (count >0 ) - { - DiaSymbolPtr child; - hres = symbols->Item(0, &child); - if (S_OK != hres) - throw DiaException("Call IDiaEnumSymbols::Item", hres); - - return SymbolPtr( new DiaSymbol(child, m_machineType) ); - } + if (selected.second > 0) + return getChildBySelected(selected, name); // _имя@парам - std::string pattern = "_"; + std::string pattern = "_"; pattern += name; pattern += "@*"; - symbols = 0; - hres = - m_symbol->findChildren( + selected = + selectChildren( ::SymTagNull, toWStr(pattern), - nsfRegularExpression | nsfCaseSensitive | nsfUndecoratedName, - &symbols); + nsfRegularExpression | nsfCaseSensitive | nsfUndecoratedName); - if (S_OK != hres) - throw DiaException("Call IDiaSymbol::findChildren", hres); - - hres = symbols->get_Count(&count); - if (S_OK != hres) - throw DiaException("Call IDiaEnumSymbols::get_Count", hres); - - if (count == 0) - throw DiaException( name + " not found"); - - if (count >0 ) - { - DiaSymbolPtr child; - hres = symbols->Item(0, &child); - if (S_OK != hres) - throw DiaException("Call IDiaEnumSymbols::Item", hres); - - return SymbolPtr( new DiaSymbol(child, m_machineType) ); - } + if (selected.second > 0) + return getChildBySelected(selected, name); throw DiaException(name + " is not found"); } ////////////////////////////////////////////////////////////////////////////// +SymbolPtr DiaSymbol::getChildBySelected(const SelectedChilds &selected, const std::string &name) +{ + if (selected.second > 1) + throw SymbolException(name + "is ambiguous"); + + DiaSymbolPtr child; + HRESULT hres = selected.first->Item(0, &child); + if (S_OK != hres) + throw DiaException("Call IDiaEnumSymbols::Item", hres); + + return SymbolPtr( new DiaSymbol(child, m_machineType) ); +} + +////////////////////////////////////////////////////////////////////////////// + ULONG DiaSymbol::getCount() { return callSymbol(get_count); diff --git a/pykd/dia/diasymbol.h b/pykd/dia/diasymbol.h index 6de4268..b14f396 100644 --- a/pykd/dia/diasymbol.h +++ b/pykd/dia/diasymbol.h @@ -104,6 +104,16 @@ protected: static const DiaRegToRegRelativeBase ®ToRegRelativeI386; ULONG getRegRealativeIdImpl(const DiaRegToRegRelativeBase &DiaRegToRegRelative); + // IDiaSymbol::findChildren/IDiaEnumSymbols::get_Count wrapper + typedef std::pair< DiaEnumSymbolsPtr, LONG > SelectedChilds; + SelectedChilds selectChildren( + ULONG symtag, + LPCOLESTR name = NULL, + DWORD compareFlags = nsNone + ); + + SymbolPtr getChildBySelected(const SelectedChilds &selected, const std::string &name); + template TRet callSymbolT( HRESULT(STDMETHODCALLTYPE IDiaSymbol::*method)(TRet *), diff --git a/test/scripts/moduletest.py b/test/scripts/moduletest.py index a457d13..1b960b2 100644 --- a/test/scripts/moduletest.py +++ b/test/scripts/moduletest.py @@ -91,7 +91,7 @@ class ModuleTest( unittest.TestCase ): self.assertEqual( 2, len(lst) ) lst = target.module.enumSymbols( "g_const*Value") self.assertEqual( 2, len(lst) ) - lst = target.module.enumSymbols( "*FuncWithName*") + lst = target.module.enumSymbols( "*FuncWithName?") self.assertEqual( 3, len(lst) ) lst = target.module.enumSymbols( "*virtFunc*") self.assertNotEqual( 0, len(lst) ) diff --git a/test/scripts/regtest.py b/test/scripts/regtest.py index 956911c..0398598 100644 --- a/test/scripts/regtest.py +++ b/test/scripts/regtest.py @@ -43,8 +43,8 @@ class CpuRegTest( unittest.TestCase ): def testFloatRegister(self): "TODO: support float point regsiters" self.assertRaises( pykd.BaseException, pykd.reg, "st0" ) - + def testMmxRegister(self): "TODO: support MMX regsiters" self.assertRaises( pykd.BaseException, pykd.reg, "mmx0" ) - + diff --git a/test/scripts/typedvar.py b/test/scripts/typedvar.py index 640b046..a11edeb 100644 --- a/test/scripts/typedvar.py +++ b/test/scripts/typedvar.py @@ -117,10 +117,13 @@ class TypedVarTest( unittest.TestCase ): except IndexError: self.assertTrue(True) - def testArrayFieldSlice(self): + def testArrayFieldSlice(self): + tv = target.module.typedVar( "g_struct3" ) + self.assertEqual( [ 0, 2 ], tv.m_arrayField[0:2] ) + + def testArrayFieldSliceNegative(self): tv = target.module.typedVar( "g_struct3" ) self.assertEqual( 2, tv.m_arrayField[-1] ) - self.assertEqual( [ 0, 2 ], tv.m_arrayField[0:2] ) def testGlobalVar(self): self.assertEqual( 4, target.module.typedVar( "g_ulongValue" ) ) diff --git a/test/scripts/typeinfo.py b/test/scripts/typeinfo.py index 43b25c9..67bd6f4 100644 --- a/test/scripts/typeinfo.py +++ b/test/scripts/typeinfo.py @@ -18,12 +18,12 @@ class TypeInfoTest( unittest.TestCase ): def testCreateByName( self ): """ creating typeInfo by the type name """ self.assertEqual( "Int4B*", target.module.type("Int4B*").name() ) - self.assertEqual( "Int4B*", pykd.typeInfo("Int4B*").name() ) self.assertEqual( "structTest", target.module.type( "structTest" ).name() ) self.assertEqual( "structTest**", target.module.type( "structTest**" ).name() ) self.assertEqual( "Int4B[2][3]", target.module.type("Int4B[2][3]").name() ) self.assertEqual( "Int4B(*[4])[2][3]", target.module.type("Int4B(*[4])[2][3]").name() ) self.assertEqual( "Int4B(*)[2][3]", target.module.type("Int4B((*))[2][3]").name() ) + self.assertEqual( "Int4B*", pykd.typeInfo("Int4B*").name() ) def testCreateBySymbol(self): """ creating typeInfo by the symbol name """ @@ -203,7 +203,7 @@ class TypeInfoTest( unittest.TestCase ): self.assertNotEqual( 0, ti.staticOffset("m_stdstr") ) if not ti.staticOffset("m_staticConst"): self.assertFalse( "MS DIA bug: https://connect.microsoft.com/VisualStudio/feedback/details/737430" ) - + def testVfnTable(self): ti = pykd.typeInfo( "g_classChild" ) self.assertTrue( hasattr( ti, "__VFN_table" ) )