diff --git a/include/core/SkRefCnt.h b/include/core/SkRefCnt.h index 7e325e0002..956397a832 100644 --- a/include/core/SkRefCnt.h +++ b/include/core/SkRefCnt.h @@ -147,5 +147,49 @@ template static inline void SkSafeUnref(T* obj) { } } +template +class SkRefPtr { +public: + SkRefPtr() : fObj(NULL) {} + SkRefPtr(T* obj) : fObj(obj) { SkSafeRef(fObj); } + SkRefPtr(const SkRefPtr& o) : fObj(o.fObj) { SkSafeRef(fObj); } + ~SkRefPtr() { SkSafeUnref(fObj); } + + SkRefPtr& operator=(const SkRefPtr& rp) { + SkRefCnt_SafeAssign(fObj, rp.fObj); + return *this; + } + SkRefPtr& operator=(T* obj) { + SkRefCnt_SafeAssign(fObj, obj); + return *this; + } + + bool operator==(const SkRefPtr& rp) const { return fObj == rp.fObj; } + bool operator==(const T* obj) const { return fObj == obj; } + bool operator!=(const SkRefPtr& rp) const { return fObj != rp.fObj; } + bool operator!=(const T* obj) const { return fObj != obj; } + + T* get() const { return fObj; } + T& operator*() const { return *fObj; } + T* operator->() const { return fObj; } + bool operator!() const { return !fObj; } + + typedef T* SkRefPtr::*unspecified_bool_type; + operator unspecified_bool_type() const { return fObj ? &SkRefPtr::fObj : NULL; } + +private: + T* fObj; +}; + +template +inline bool operator==(T* obj, const SkRefPtr& rp) { + return obj == rp.get(); +} + +template +inline bool operator!=(T* obj, const SkRefPtr& rp) { + return obj != rp.get(); +} + #endif diff --git a/tests/UtilsTest.cpp b/tests/UtilsTest.cpp index 8a8319c74b..8ec063e77d 100644 --- a/tests/UtilsTest.cpp +++ b/tests/UtilsTest.cpp @@ -1,9 +1,50 @@ #include "Test.h" #include "SkRandom.h" +#include "SkRefCnt.h" #include "SkTSearch.h" #include "SkTSort.h" #include "SkUtils.h" +class RefClass : public SkRefCnt { +public: + RefClass(int n) : fN(n) {} + int get() const { return fN; } + +private: + int fN; +}; + +static void test_refptr(skiatest::Reporter* reporter) { + RefClass* r0 = new RefClass(0); + + SkRefPtr rc0; + REPORTER_ASSERT(reporter, rc0.get() == NULL); + REPORTER_ASSERT(reporter, !rc0); + + SkRefPtr rc1; + REPORTER_ASSERT(reporter, rc0 == rc1); + REPORTER_ASSERT(reporter, rc0 != r0); + + rc0 = r0; + REPORTER_ASSERT(reporter, rc0); + REPORTER_ASSERT(reporter, rc0 != rc1); + REPORTER_ASSERT(reporter, rc0 == r0); + + rc1 = rc0; + REPORTER_ASSERT(reporter, rc1); + REPORTER_ASSERT(reporter, rc0 == rc1); + REPORTER_ASSERT(reporter, rc0 == r0); + + rc0 = NULL; + REPORTER_ASSERT(reporter, rc0.get() == NULL); + REPORTER_ASSERT(reporter, !rc0); + REPORTER_ASSERT(reporter, rc0 != rc1); + + r0->unref(); +} + +/////////////////////////////////////////////////////////////////////////////// + #define kSEARCH_COUNT 91 static void test_search(skiatest::Reporter* reporter) { @@ -103,6 +144,7 @@ static void TestUTF(skiatest::Reporter* reporter) { test_utf16(reporter); test_search(reporter); + test_refptr(reporter); } #include "TestClassDef.h"