From a26edc820852649c4103e897404a5259091feaea Mon Sep 17 00:00:00 2001
From: "SND\\EreTIk_cp" <SND\EreTIk_cp@9b283d60-5439-405e-af05-b73fd8c4d996>
Date: Wed, 8 May 2013 11:51:45 +0000
Subject: [PATCH] [0.2.x]  ~ refactoring: DIA findChildren  ~ refactoring:
 failed tests -> tail

git-svn-id: https://pykd.svn.codeplex.com/svn@83636 9b283d60-5439-405e-af05-b73fd8c4d996
---
 pykd/dia/diasymbol.cpp     | 188 +++++++++++++------------------------
 pykd/dia/diasymbol.h       |  10 ++
 test/scripts/moduletest.py |   2 +-
 test/scripts/regtest.py    |   4 +-
 test/scripts/typedvar.py   |   7 +-
 test/scripts/typeinfo.py   |   4 +-
 6 files changed, 85 insertions(+), 130 deletions(-)

diff --git a/pykd/dia/diasymbol.cpp b/pykd/dia/diasymbol.cpp
index 33df9f1..c72c904 100644
--- a/pykd/dia/diasymbol.cpp
+++ b/pykd/dia/diasymbol.cpp
@@ -55,37 +55,45 @@ SymbolPtr DiaSymbol::fromGlobalScope( IDiaSymbol *_symbol )
 
 //////////////////////////////////////////////////////////////////////////////////
 
-SymbolPtrList DiaSymbol::findChildren(ULONG symTag, const std::string &name)
+DiaSymbol::SelectedChilds DiaSymbol::selectChildren(
+    ULONG symtag,
+    LPCOLESTR name,
+    DWORD compareFlags
+)
 {
+    BOOST_ASSERT(symtag < SymTagMax);
+
+    const bool findAll = 
+        ( !name || !*name || (name[0] == L'*' && name[1] == L'\0') );
+
     DiaEnumSymbolsPtr symbols;
-    HRESULT hres;
 
-    const bool bFindAllNames = ( name.empty() || name == "*" );
-
-    if ( bFindAllNames )
-    {
-        hres = 
-            m_symbol->findChildren(
-                static_cast<enum ::SymTagEnum>(symTag),
-                NULL,
-                nsNone,
-                &symbols);
-    }
-    else
-    {
-        hres = 
-            m_symbol->findChildren(
-                static_cast<enum ::SymTagEnum>(symTag),
-                toWStr(name),
-                nsRegularExpression,
-                &symbols);
-    }
+    HRESULT hres = 
+        m_symbol->findChildren(
+            static_cast<enum ::SymTagEnum>(symtag),
+            ( findAll ? NULL : name ),
+            ( findAll ? nsNone : compareFlags),
+            &symbols);
 
     if (S_OK != hres)
         throw DiaException("IDiaSymbol::findChildren", hres);
 
-    SymbolPtrList childList;
+    LONG count = 0;
+    hres = symbols->get_Count(&count);
+    if (S_OK != hres)
+        throw DiaException("IDiaEnumSymbols::get_Count", hres);
 
+    return SelectedChilds(symbols, count);
+}
+
+//////////////////////////////////////////////////////////////////////////////////
+
+SymbolPtrList DiaSymbol::findChildren(ULONG symTag, const std::string &name)
+{
+    DiaEnumSymbolsPtr symbols = 
+        selectChildren(symTag, toWStr(name), nsRegularExpression).first;
+
+    SymbolPtrList childList;
     DiaSymbolPtr child;
     ULONG celt = 0;
     while ( SUCCEEDED(symbols->Next(1, &child, &celt)) && (celt == 1) )
@@ -115,50 +123,22 @@ ULONG DiaSymbol::getBitPosition()
 
 ULONG DiaSymbol::getChildCount(ULONG symTag)
 {
-    DiaEnumSymbolsPtr symbols;
-    HRESULT hres = 
-        m_symbol->findChildren(
-            static_cast<enum ::SymTagEnum>(symTag),
-            NULL,
-            nsfCaseSensitive | nsfUndecoratedName,
-            &symbols);
-    if (S_OK != hres)
-        throw DiaException("Call IDiaSymbol::findChildren", hres);
-
-    LONG count;
-    hres = symbols->get_Count(&count);
-    if (S_OK != hres)
-        throw DiaException("Call IDiaEnumSymbols::get_Count", hres);
-
-    return count;
+    return selectChildren(symTag).second;
 }
 
 ////////////////////////////////////////////////////////////////////////////////
 
 SymbolPtr DiaSymbol::getChildByIndex(ULONG symTag, ULONG _index )
 {
-    DiaEnumSymbolsPtr symbols;
-    HRESULT hres = 
-        m_symbol->findChildren(
-            static_cast<enum ::SymTagEnum>(symTag),
-            NULL,
-            nsfCaseSensitive | nsfUndecoratedName,
-            &symbols);
-    if (S_OK != hres)
-        throw DiaException("Call IDiaSymbol::findChildren", hres);
+    SelectedChilds selected = selectChildren(symTag);
 
-    LONG count;
-    hres = symbols->get_Count(&count);
-    if (S_OK != hres)
-        throw DiaException("Call IDiaEnumSymbols::get_Count", hres);
-
-    if (LONG(_index) >= count)
+    if (LONG(_index) >= selected.second)
     {
         throw PyException( PyExc_IndexError, "Index out of range");
     }
 
     DiaSymbolPtr child;
-    hres = symbols->Item(_index, &child);
+    HRESULT hres = selected.first->Item(_index, &child);
     if (S_OK != hres)
         throw DiaException("Call IDiaEnumSymbols::Item", hres);
 
@@ -167,33 +147,12 @@ SymbolPtr DiaSymbol::getChildByIndex(ULONG symTag, ULONG _index )
 
 ////////////////////////////////////////////////////////////////////////////////
 
-SymbolPtr DiaSymbol::getChildByName(const std::string &name )
+SymbolPtr DiaSymbol::getChildByName(const std::string &name)
 {
-    DiaEnumSymbolsPtr symbols;
-    HRESULT hres = 
-        m_symbol->findChildren(
-            ::SymTagNull,
-            toWStr(name),
-            nsfCaseSensitive,
-            &symbols);
+    SelectedChilds selected = selectChildren(::SymTagNull, toWStr(name), nsCaseSensitive);
 
-    LONG count;
-    hres = symbols->get_Count(&count);
-    if (S_OK != hres)
-        throw DiaException("Call IDiaEnumSymbols::get_Count", hres);
-
-    if (count > 0)
-    {
-        if (count > 1)
-            throw SymbolException(name + "is ambiguous");
-
-        DiaSymbolPtr child;
-        hres = symbols->Item(0, &child);
-        if (S_OK != hres)
-            throw DiaException("Call IDiaEnumSymbols::Item", hres);
-
-        return SymbolPtr( new DiaSymbol(child, m_machineType) );
-    }
+    if (selected.second > 0)
+        return getChildBySelected(selected, name);
 
     if (m_publicSymbols)
     {
@@ -208,67 +167,50 @@ SymbolPtr DiaSymbol::getChildByName(const std::string &name )
     std::string underscoreName;
     underscoreName += '_';
     underscoreName += name;
-    symbols = 0;
 
-    hres = 
-        m_symbol->findChildren(
+    selected = 
+        selectChildren(
             ::SymTagNull,
             toWStr(underscoreName),
-            nsfCaseSensitive | nsfUndecoratedName,
-            &symbols);
+            nsfCaseSensitive | nsfUndecoratedName);
 
-    hres = symbols->get_Count(&count);
-    if (S_OK != hres)
-        throw DiaException("Call IDiaEnumSymbols::get_Count", hres);
-
-    if (count >0 )
-    {
-        DiaSymbolPtr child;
-        hres = symbols->Item(0, &child);
-        if (S_OK != hres)
-            throw DiaException("Call IDiaEnumSymbols::Item", hres);
-
-        return SymbolPtr( new DiaSymbol(child, m_machineType) );
-    }
+    if (selected.second > 0)
+        return getChildBySelected(selected, name);
 
     // _���@�����
-    std::string     pattern = "_";
+    std::string pattern = "_";
     pattern += name;
     pattern += "@*";
-    symbols = 0;
 
-    hres = 
-        m_symbol->findChildren(
+    selected = 
+        selectChildren(
             ::SymTagNull,
             toWStr(pattern),
-            nsfRegularExpression | nsfCaseSensitive | nsfUndecoratedName,
-            &symbols);
+            nsfRegularExpression | nsfCaseSensitive | nsfUndecoratedName);
 
-    if (S_OK != hres)
-        throw DiaException("Call IDiaSymbol::findChildren", hres);
-
-    hres = symbols->get_Count(&count);
-    if (S_OK != hres)
-        throw DiaException("Call IDiaEnumSymbols::get_Count", hres);
-
-    if (count == 0)
-         throw DiaException( name + " not found");
-
-    if (count >0 )
-    {
-        DiaSymbolPtr child;
-        hres = symbols->Item(0, &child);
-        if (S_OK != hres)
-            throw DiaException("Call IDiaEnumSymbols::Item", hres);
-
-        return SymbolPtr( new DiaSymbol(child, m_machineType) );
-    }
+    if (selected.second > 0)
+        return getChildBySelected(selected, name);
 
     throw DiaException(name + " is not found");
 }
 
 //////////////////////////////////////////////////////////////////////////////
 
+SymbolPtr DiaSymbol::getChildBySelected(const SelectedChilds &selected, const std::string &name)
+{
+    if (selected.second > 1)
+        throw SymbolException(name + "is ambiguous");
+
+    DiaSymbolPtr child;
+    HRESULT hres = selected.first->Item(0, &child);
+    if (S_OK != hres)
+        throw DiaException("Call IDiaEnumSymbols::Item", hres);
+
+    return SymbolPtr( new DiaSymbol(child, m_machineType) );
+}
+
+//////////////////////////////////////////////////////////////////////////////
+
 ULONG DiaSymbol::getCount()
 {
     return callSymbol(get_count);
diff --git a/pykd/dia/diasymbol.h b/pykd/dia/diasymbol.h
index 6de4268..b14f396 100644
--- a/pykd/dia/diasymbol.h
+++ b/pykd/dia/diasymbol.h
@@ -104,6 +104,16 @@ protected:
     static const DiaRegToRegRelativeBase &regToRegRelativeI386;
     ULONG getRegRealativeIdImpl(const DiaRegToRegRelativeBase &DiaRegToRegRelative);
 
+    // IDiaSymbol::findChildren/IDiaEnumSymbols::get_Count wrapper
+    typedef std::pair< DiaEnumSymbolsPtr, LONG > SelectedChilds;
+    SelectedChilds selectChildren(
+        ULONG symtag,
+        LPCOLESTR name = NULL,
+        DWORD compareFlags = nsNone
+    );
+
+    SymbolPtr getChildBySelected(const SelectedChilds &selected, const std::string &name);
+
     template <typename TRet>
     TRet callSymbolT(
         HRESULT(STDMETHODCALLTYPE IDiaSymbol::*method)(TRet *),
diff --git a/test/scripts/moduletest.py b/test/scripts/moduletest.py
index a457d13..1b960b2 100644
--- a/test/scripts/moduletest.py
+++ b/test/scripts/moduletest.py
@@ -91,7 +91,7 @@ class ModuleTest( unittest.TestCase ):
         self.assertEqual( 2, len(lst) )
         lst = target.module.enumSymbols( "g_const*Value")
         self.assertEqual( 2, len(lst) )
-        lst = target.module.enumSymbols( "*FuncWithName*")
+        lst = target.module.enumSymbols( "*FuncWithName?")
         self.assertEqual( 3, len(lst) )
         lst = target.module.enumSymbols( "*virtFunc*") 
         self.assertNotEqual( 0, len(lst) )
diff --git a/test/scripts/regtest.py b/test/scripts/regtest.py
index 956911c..0398598 100644
--- a/test/scripts/regtest.py
+++ b/test/scripts/regtest.py
@@ -43,8 +43,8 @@ class CpuRegTest( unittest.TestCase ):
     def testFloatRegister(self):
         "TODO: support float point regsiters"
         self.assertRaises( pykd.BaseException, pykd.reg, "st0" )
-        
+
     def testMmxRegister(self):
         "TODO: support MMX regsiters"
         self.assertRaises( pykd.BaseException, pykd.reg, "mmx0" )
-        
+
diff --git a/test/scripts/typedvar.py b/test/scripts/typedvar.py
index 640b046..a11edeb 100644
--- a/test/scripts/typedvar.py
+++ b/test/scripts/typedvar.py
@@ -117,10 +117,13 @@ class TypedVarTest( unittest.TestCase ):
         except IndexError:
             self.assertTrue(True)
 
-    def testArrayFieldSlice(self): 
+    def testArrayFieldSlice(self):
+        tv = target.module.typedVar( "g_struct3" )
+        self.assertEqual( [ 0, 2 ], tv.m_arrayField[0:2] )
+
+    def testArrayFieldSliceNegative(self):
         tv = target.module.typedVar( "g_struct3" )
         self.assertEqual( 2, tv.m_arrayField[-1] )
-        self.assertEqual( [ 0, 2 ], tv.m_arrayField[0:2] )
 
     def testGlobalVar(self):
         self.assertEqual( 4, target.module.typedVar( "g_ulongValue" ) )
diff --git a/test/scripts/typeinfo.py b/test/scripts/typeinfo.py
index 43b25c9..67bd6f4 100644
--- a/test/scripts/typeinfo.py
+++ b/test/scripts/typeinfo.py
@@ -18,12 +18,12 @@ class TypeInfoTest( unittest.TestCase ):
     def testCreateByName( self ):
         """ creating typeInfo by the type name """
         self.assertEqual( "Int4B*", target.module.type("Int4B*").name() )
-        self.assertEqual( "Int4B*", pykd.typeInfo("Int4B*").name() )
         self.assertEqual( "structTest", target.module.type( "structTest" ).name() )
         self.assertEqual( "structTest**", target.module.type( "structTest**" ).name() )
         self.assertEqual( "Int4B[2][3]", target.module.type("Int4B[2][3]").name() )
         self.assertEqual( "Int4B(*[4])[2][3]", target.module.type("Int4B(*[4])[2][3]").name() )
         self.assertEqual( "Int4B(*)[2][3]", target.module.type("Int4B((*))[2][3]").name() )
+        self.assertEqual( "Int4B*", pykd.typeInfo("Int4B*").name() )
 
     def testCreateBySymbol(self):
         """ creating typeInfo by the symbol name """
@@ -203,7 +203,7 @@ class TypeInfoTest( unittest.TestCase ):
         self.assertNotEqual( 0, ti.staticOffset("m_stdstr") )
         if not ti.staticOffset("m_staticConst"):
             self.assertFalse( "MS DIA bug: https://connect.microsoft.com/VisualStudio/feedback/details/737430" )
-            
+
     def testVfnTable(self):
         ti = pykd.typeInfo( "g_classChild" )
         self.assertTrue( hasattr( ti, "__VFN_table" ) )