C++ Shared Pointer

Let’s write a C++ shared pointer. Shared pointers are smart pointers that free the memory they point to once the last reference to it is destroyed.
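For comparison, the standard library already provides this behavior as `std::shared_ptr` in `<memory>`. The sketch below is illustrative only, and `Tracked` and `freedOnlyByLastOwner` are hypothetical names used for this example. It copies a `std::shared_ptr`, checks `use_count()`, and observes that the object is destroyed only after the last owner lets go of it.

```cpp
#include <cassert>
#include <memory>

// Hypothetical helper type: sets a flag when it is destroyed, so we can
// observe exactly when std::shared_ptr frees it
struct Tracked {
    explicit Tracked(bool* flag) : destroyed(flag) {}
    ~Tracked() { *destroyed = true; }
    bool* destroyed;
};

// Returns true if the object outlived its first owner and was freed
// only when the last owner let go of it
bool freedOnlyByLastOwner() {
    bool destroyed = false;
    std::shared_ptr<Tracked> first = std::make_shared<Tracked>(&destroyed);
    std::shared_ptr<Tracked> second = first;   // copy: two owners share one count
    assert(first.use_count() == 2);

    first.reset();                             // first owner lets go
    bool survived = !destroyed && second.use_count() == 1;

    second.reset();                            // last owner lets go: object is deleted
    return survived && destroyed;
}
```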

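Fundamentally, smart pointers take advantage of a class’s destructor: when an object goes out of scope, its destructor is called automatically. Every shared pointer to the same data shares one reference count, which is incremented when a shared pointer is copied and decremented when one is destroyed. When a shared pointer goes out of scope, logic in its destructor decrements that count and checks whether it was the last reference to the data it points to. If it was, it frees the memory. This is an application of RAII (Resource Acquisition Is Initialization): the lifetime of the memory is tied to the lifetime of the objects that own it.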
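A minimal sketch of the destructor mechanism on its own, assuming a hypothetical `ScopeCounter` class that is not part of the shared pointer below: the constructor "acquires" something (here it bumps a counter) and the destructor releases it as soon as the object leaves its scope, with no explicit cleanup call.

```cpp
#include <cassert>

// Hypothetical RAII type: "acquires" by incrementing a counter in its
// constructor and "releases" by decrementing it in its destructor
class ScopeCounter {
public:
    explicit ScopeCounter(int& live) : m_live(live) { ++m_live; }
    ~ScopeCounter() { --m_live; }

    // Non-copyable, so each acquisition is released exactly once
    ScopeCounter(const ScopeCounter&) = delete;
    ScopeCounter& operator=(const ScopeCounter&) = delete;

private:
    int& m_live;
};

// Returns the counter's value observed inside the inner scope
int liveCountInsideScope() {
    int live = 0;
    int inside = 0;
    {
        ScopeCounter guard(live);  // constructor runs: live becomes 1
        inside = live;
    }                              // guard goes out of scope: destructor runs
    assert(live == 0);             // released without any explicit cleanup call
    return inside;
}
```

A shared pointer adds one twist to this pattern: its destructor only releases the memory when the shared count says no other owner remains.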

#include <iostream>

// A simple struct to create a shared pointer to
struct Entry {
    Entry(int idx, int val) :
        m_index(idx),
        m_value(val) {}

    int m_index;
    int m_value;
};

// Reference count shared by every SharedPointer that points to the same data
class RefCount {
public:
    void increment() {
        m_referenceCount++;
    }

    void decrement() {
        m_referenceCount--;
    }

    void operator++() {
        increment();
    }

    void operator--() {
        decrement();
    }

    bool hasReference() const {
        return m_referenceCount > 0;
    }

private:
    // Starts at 1: the pointer created alongside this count is the first reference
    int m_referenceCount = 1;
};

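template<typename DT>
class SharedPointer {
public:
    /*
     *   Copy Constructor
     *   e.g.
     *   SharedPointer<DT> data;
     *   SharedPointer<DT> otherData = data; // called here
     */
    SharedPointer(const SharedPointer<DT>& other) :
        m_data(other.m_data),
        m_refCount(other.m_refCount) {
        // A default-constructed pointer has no count to share
        if (m_refCount) {
            m_refCount->increment();
        }
    }

    // Empty pointer: owns nothing, so there is no reference count yet
    SharedPointer() :
        m_data(nullptr),
        m_refCount(nullptr) {}

    // Takes ownership of data; refCount must be a fresh count starting at 1
    SharedPointer(DT* data, RefCount* refCount) :
        m_data(data),
        m_refCount(refCount) {}

    // Releases this reference, freeing the data if it was the last one
    ~SharedPointer() {
        decrementAndDelete();
    }
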
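    /*
     *   Assignment Operator
     *   SharedPointer<DT> data;
     *   SharedPointer<DT> otherData;
     *   otherData = data; // called here
     */
    SharedPointer<DT>& operator=(const SharedPointer<DT>& other) {
        // Self-assignment (data = data) must not release the data first
        if (this != &other) {
            // Drop our current reference; this may free the old data
            decrementAndDelete();

            // Share the other pointer's data and reference count
            m_data = other.m_data;
            m_refCount = other.m_refCount;

            if (m_refCount) {
                m_refCount->increment();
            }
        }
        return *this;
    }
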
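    DT& operator*() const {
        return *m_data;
    }

    // Member access, e.g. sp->m_value
    DT* operator->() const {
        return m_data;
    }

    DT* getData() const {
        return m_data;
    }

    RefCount* getRefCount() const {
        return m_refCount;
    }

private:

    // Decrements the shared count and frees both the data and the count
    // once no references remain
    void decrementAndDelete() {
        if (m_refCount) {
            m_refCount->decrement();

            if (!m_refCount->hasReference()) {
                delete m_data;
                delete m_refCount;
                m_data = nullptr;
                m_refCount = nullptr;
            }
        }
    }

    DT* m_data;
    RefCount* m_refCount;
};
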
// Wraps a raw pointer (allocated with new) in a SharedPointer with a fresh reference count
template<typename DT>
SharedPointer<DT> make_shared_pointer(DT* data) {
    RefCount* rc = new RefCount();
    return SharedPointer<DT>(data, rc);
}

int main() {

    Entry* e = new Entry(1, 1);
    SharedPointer<Entry> sp = make_shared_pointer<Entry>(e);   // reference count: 1

    std::cout << (*sp).m_value << std::endl;

    SharedPointer<Entry> sp_2 = sp;   // copy constructor, reference count: 2

    std::cout << (*sp_2).m_value << std::endl;

    // sp_2 and then sp go out of scope here; the second destructor frees e
    return 0;
}
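One caveat with the `RefCount` class above: it uses a plain `int`, so copying and destroying `SharedPointer`s that share the same data from several threads at once would be a data race. `std::shared_ptr` avoids this by updating its reference count atomically. The following is a sketch of how the count could be made thread-safe with `std::atomic`, assuming a hypothetical `AtomicRefCount` class whose `decrement()` reports whether it released the last reference. That interface differs from the `decrement()`/`hasReference()` pair above on purpose: even with an atomic counter, decrementing and then checking in two separate steps would let two threads both see zero.

```cpp
#include <atomic>
#include <cassert>
#include <thread>
#include <vector>

// Hypothetical thread-safe replacement for RefCount
class AtomicRefCount {
public:
    void increment() {
        // Relaxed ordering suffices: a new reference is always copied from an
        // existing one, so the count cannot reach zero concurrently
        m_count.fetch_add(1, std::memory_order_relaxed);
    }

    // Returns true if this call released the last reference
    bool decrement() {
        // acq_rel: the thread that frees the data sees every write made
        // through the other references before they were released
        return m_count.fetch_sub(1, std::memory_order_acq_rel) == 1;
    }

    int get() const {
        return m_count.load(std::memory_order_relaxed);
    }

private:
    std::atomic<int> m_count{1};  // the creating pointer is the first reference
};

// Several threads take and drop references concurrently; the original
// reference is never released, so the count should end back at 1
int countAfterConcurrentCopies(int threads, int iterations) {
    AtomicRefCount rc;
    std::vector<std::thread> workers;
    for (int t = 0; t < threads; ++t) {
        workers.emplace_back([&rc, iterations] {
            for (int i = 0; i < iterations; ++i) {
                rc.increment();                 // like copying a SharedPointer
                bool wasLast = rc.decrement();  // like destroying that copy
                assert(!wasLast);               // the original reference is still held
                (void)wasLast;                  // avoid an unused warning under NDEBUG
            }
        });
    }
    for (std::thread& w : workers) {
        w.join();
    }
    return rc.get();
}
```

With this interface, `decrementAndDelete()` would delete the data and the count only when `decrement()` returns `true`.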