Sum of pairwise Hamming Distance | Same problem | O(n)


#1
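    int Solution::cntBits(vector<int> &A) {
        const long long MOD = 1000000007;
        long long ans = 0, n = A.size();
        long long count;
        // Assumes non-negative inputs, so bits 0..30 cover every value.
        for (int i = 0; i < 31; i++) {
            // count = number of elements whose bit i is set.
            count = 0;
            for (int j = 0; j < n; j++) {
                if (A[j] & (1 << i))
                    count++;
            }
            // Each of the count ones pairs with each of the (n - count)
            // zeros to differ at bit i; x2 because (x, y) and (y, x) both count.
            ans = (ans + count * (n - count) * 2) % MOD;
        }
        return ans;
    }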
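To try the loop outside the judge, here is a minimal sketch of the same idea as a free function. The name `sumPairwiseHamming` is hypothetical (the original is a member of InterviewBit's `Solution` class), and it assumes non-negative inputs so that bits 0..30 are enough.

```cpp
#include <vector>

// Hypothetical free-function version of Solution::cntBits.
long long sumPairwiseHamming(const std::vector<int>& A) {
    const long long MOD = 1000000007LL;
    const long long n = static_cast<long long>(A.size());
    long long ans = 0;
    // Assumes non-negative ints: bits 0..30 hold every set bit.
    for (int bit = 0; bit < 31; ++bit) {
        long long ones = 0;  // elements with this bit set
        for (int x : A) {
            if (x & (1 << bit)) ++ones;
        }
        long long zeros = n - ones;  // elements with this bit clear
        // Every (one, zero) pair differs at this bit; count both orders.
        // 2 * ones * zeros <= n * n / 2, which fits in long long here.
        ans = (ans + 2 * ones * zeros) % MOD;
    }
    return ans;
}
```

For example, `sumPairwiseHamming({2, 4, 6})` gives 8: bits 1 and 2 each have two 1s and one 0, so each contributes 2 * 2 * 1 = 4, and bit 0 contributes nothing.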

#3

Please explain this: count*(n-count)*2
What is it doing?


#4
    if(A[j]&(1<<i))
        count++;

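We increment count whenever (A[j] & (1<<i)) != 0, i.e. whenever A[j] has a 1 at bit position i. (The result of the AND is 1<<i, not 1, so testing for == 1 would be wrong for i > 0.) So count is the number of values, out of n, that have a 1 at bit i, and n - count is the number that have a 0 there.
Two values differ at bit i exactly when one has a 1 there and the other has a 0, so there are count*(n-count) such pairs. Ordered pairs are counted, so (x, y) and (y, x) are both included, and bit i contributes 2*count*(n-count) to the total.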
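Summing that per-bit contribution over the 31 bit positions the loop visits gives the value the code returns. Here $c_i$ is just notation for `count` at bit $i$:

```latex
\text{ans} = \left( \sum_{i=0}^{30} 2\, c_i \,(n - c_i) \right) \bmod (10^9 + 7)
```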
Example: let n = 3 with one 1 and two 0s at bit i, so count = 1 and n - count = 3 - 1 = 2.
The ordered pairs that differ at that bit are:
(1,0) ; (1,0) ; (0,1) ; (0,1)  // 4 pairs
That matches 2 * 1 * 2 = 4.
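One way to check the formula is to compare it against a brute force that walks every ordered pair, just like the hand enumeration above. This is a small sketch: `bruteForceHamming` and `perBitHamming` are hypothetical helper names, and `perBitHamming` skips the modulo, so it is only meant for small test inputs.

```cpp
#include <bitset>
#include <vector>

// Brute force: add the Hamming distance of every ordered pair (x, y),
// including x paired with itself (which contributes 0).
long long bruteForceHamming(const std::vector<int>& A) {
    long long total = 0;
    for (int x : A) {
        for (int y : A) {
            // x ^ y has a 1 wherever x and y differ; count those bits.
            total += static_cast<long long>(
                std::bitset<32>(static_cast<unsigned>(x ^ y)).count());
        }
    }
    return total;
}

// Per-bit counting: 2 * count * (n - count) for each bit, no modulo.
// Assumes non-negative inputs (bits 0..30).
long long perBitHamming(const std::vector<int>& A) {
    const long long n = static_cast<long long>(A.size());
    long long total = 0;
    for (int i = 0; i < 31; ++i) {
        long long count = 0;
        for (int x : A) {
            if (x & (1 << i)) ++count;
        }
        total += 2 * count * (n - count);
    }
    return total;
}
```

Both return 4 for `{1, 0, 0}`, the single-bit column from the example, and both return 8 for `{2, 4, 6}`.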