实现要点

  • shared_ptr 的引用计数应当使用原子操作(std::atomic)确保最低限度的并发安全。
  • 多个指向同一个底层指针对象的shared_ptr,它们的引用计数应当相等,因此应该将引用计数也定义成指针。
  • 使用移动赋值运算时,被移动的shared_ptr的引用计数应当减1,同时要判断是否释放资源(引用计数为0时)。
  • 使用移动构造函数时,要将被移动的对象的底层指针和引用计数器置空。
  • 重载解引用 * 和箭头 -> 运算符,实现类似普通指针的调用
  • 使用RAII机制自动释放资源

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
template <typename T>
class SharedPtr
{
private:
T *ptr;
std::atomic<int> *count;

public:
explicit SharedPtr(T *p = nullptr) : ptr(p), count(new std::atomic<int>(1))
{
if (p == nullptr)
{
count->store(0);
}
}

SharedPtr(const SharedPtr<T> &other) : ptr(other.ptr), count(other.count)
{
count->fetch_add(1);
}

SharedPtr(SharedPtr<T> &&other) noexcept : ptr(other.ptr), count(other.count)
{
other.ptr = nullptr;
other.count = nullptr;
}

SharedPtr<T> &operator=(const SharedPtr<T> &other)
{
if (this->ptr != other.ptr)
{
if (count->fetch_sub(1) == 0)
{
delete ptr;
delete count;
}
this->ptr = other.ptr;
this->count = other.count;
if (this->ptr)
{
count->fetch_add(1);
}
}
}

SharedPtr<T> &operator=(SharedPtr<T> &&other)
{
if (this->ptr != other.ptr)
{
if (count->fetch_sub(1) == 0)
{
delete ptr;
delete count;
}
this->ptr = other.ptr;
this->count = other.count;
other.ptr = nullptr;
other.count = nullptr;
}
}

~SharedPtr()
{
if (count->fetch_sub(1) == 0)
{
delete ptr;
delete count;
}
}

T *operator->() const
{
return ptr;
}

T &operator*() const
{
return *ptr;
}

T *get() const noexcept
{
return ptr;
}

int use_count() const noexcept
{
return count->load();
}

bool unique() const noexcept
{
return count->load() == 1;
}

void swap(SharedPtr<T> &other) noexcept
{
std::swap(ptr, other.ptr);
std::swap(count, other.count);
}

void reset(T *p = nullptr) noexcept
{
if (this->ptr != p)
{
if ((*count).fetch_sub(1) == 0)
{
delete ptr;
delete count;
}
ptr = p;
if (p != nullptr)
{
count = new std::atomic<int>(1);
}
}
}
};