#include <iostream>
#include <memory>
#include <string>
/// @brief 手写共享指针
/// @tparam T 指针内存类型
template <class T>
class shared_ptr_t {
private:
T* ptr_;
int* count_;
private:
/// @brief 释放内存
void release() {
if (ptr_ && (--(*count_)) == 0) {
delete ptr_;
delete count_;
}
}
/// @brief 根据传进来的指针,增加引用计数
/// @param sp 指针
void add_count(shared_ptr_t<T>& sp) {
ptr_ = sp.ptr_;
count_ = sp.count_;
if (ptr_) {
++(*count_);
}
}
public:
/// @brief 构造函数
/// @param ptr 指针
shared_ptr_t(T* ptr = nullptr) : ptr_(ptr) {
if (ptr_) {
count_ = new int(1);
}
}
/// @brief 拷贝构造函数
/// @param sp 指针
shared_ptr_t(shared_ptr_t& sp) { add_count(sp); }
/// @brief 赋值函数
/// @param sp 指针
/// @return shared_ptr_t<T>&
shared_ptr_t<T>& operator=(shared_ptr_t<T>& sp) {
if (this != &sp) {
release();
add_count(sp);
}
return *this;
}
/// @brief 析构函数
~shared_ptr_t() {
printf("destruct shared_ptr count: %d, 若解决循环引用的问题最后一行应该显示为1\n", *count_);
release();
}
/// @brief 重载*和->操作符
/// @return T&
T& operator*() { return *ptr_; }
T* operator->() { return ptr_; }
T* get() { return ptr_; }
int use_count() { return *count_; }
int* count() { return count_; }
};
/// @brief 手写弱指针
/// @tparam T 指针内存类型
template <class T>
class weak_ptr_t {
private:
int* count_;
T* ptr_;
public:
/// @brief 弱指针构造函数
/// @param ptr 传入的指针
weak_ptr_t(T* ptr = nullptr) : count_(nullptr), ptr_(nullptr) {}
/// @brief 拷贝构造函数,需要搭配手写shared_ptr_t使用,获取共享指针的count指针和ptr指针
/// @param sp 传入的指针
weak_ptr_t(shared_ptr_t<T>& sp) : count_(sp.count()), ptr_(sp.get()) {}
/// @brief 赋值函数
/// @param sp 传入的指针
/// @return weak_ptr_t<T>&
weak_ptr_t<T>& operator=(shared_ptr_t<T>& sp) {
count_ = sp.count();
ptr_ = sp.get();
return *this;
}
/// @brief lock函数
/// @return 若引用计数大于0,返回指针,否则返回空指针
T* lock() {
if (count_ && *count_ > 0) {
return ptr_;
}
return nullptr;
}
/// @brief 获取引用计数
/// @return int 引用计数
int use_count() { return *count_; }
/// @brief 获取引用计数指针
/// @return int* 引用计数指针
int* count() { return count_; }
};
class ListNode {
public:
int val;
#if false
shared_ptr_t<ListNode> next;
shared_ptr_t<ListNode> previous;
#else
weak_ptr_t<ListNode> next;
weak_ptr_t<ListNode> previous;
#endif
ListNode(int x) : val(x), next(nullptr), previous(nullptr) {}
};
int main() {
// 使用weak_ptr解决shared指针循环引用的问题
shared_ptr_t<ListNode> p3(new ListNode(800));
shared_ptr_t<ListNode> p4(new ListNode(800));
p3->next = p4;
p4->previous = p3;
printf("p3 use count: %d\n", p3.use_count());
printf("p4 use count: %d\n", p4.use_count());
}
共享指针循环引用使用弱指针解决
来自
标签:
发表回复