from scipy import *
import numpy as np

def dist(X, Y, W=0):
    '''Calculate the weighted squared distance between all pairs of points
    in two matrices.

    Each row of X is one point and each column of Y is one point.  For
    row x_i of X and column y_j of Y the result holds the quadratic form

        D[i, j] = (x_i - y_j)^T  W  (x_i - y_j)

    i.e. the squared Mahalanobis distance when W is an inverse covariance
    matrix, and the squared Euclidean distance when W is the identity.
    No square root is taken.

    Parameters:
        X: n x m array-like; the rows are points.
        Y: m x p array-like; the columns are points.
        W: m x m weight matrix.  The default 0 (or None) means "not
           supplied" and selects the identity.  Any other scalar is
           passed to the dot product as-is and therefore acts as a
           scalar multiple of the identity.

    Returns:
        An n x p float array of pairwise weighted squared distances.

    Raises:
        ValueError: if X or Y is not 2-D, or if the number of columns of
            X differs from the number of rows of Y.
    '''
    X = np.asarray(X)
    Y = np.asarray(Y)
    if X.ndim != 2 or Y.ndim != 2:
        raise ValueError('X and Y must both be 2-D, got shapes %s and %s'
                         % (X.shape, Y.shape))

    n, m = X.shape
    m_y, p = Y.shape
    if m != m_y:
        raise ValueError('X has %d columns but Y has %d rows' % (m, m_y))

    # Only None or a *scalar* zero means "no weight matrix".  Comparing a
    # real matrix against 0 is elementwise and must not be used as a
    # truth value.
    if W is None or (np.ndim(W) == 0 and W == 0):
        W = np.eye(m)
    else:
        W = np.asarray(W)

    # Loop along the shorter dimension.  The quadratic form is unchanged
    # when the difference vector changes sign, so dist(X, Y) is the
    # transpose of dist(Y.T, X.T); both operands must be transposed so
    # that rows of the first and columns of the second are still points.
    swapped = n > p
    if swapped:
        X, Y, n, p = Y.T, X.T, p, n

    D = np.empty((n, p), dtype=float)
    for i, point in enumerate(X):
        # Column j of M is the difference between this point and y_j.
        M = point[:, np.newaxis] - Y
        # Reduce over the feature axis only, giving one value per y_j.
        D[i, :] = np.sum(M * np.dot(W, M), axis=0)

    # Correct for switching X and Y if necessary
    return D.T if swapped else D
