001// --- BEGIN LICENSE BLOCK ---
002/* 
003 * Copyright (c) 2009, Mikio L. Braun
004 * All rights reserved.
005 * 
006 * Redistribution and use in source and binary forms, with or without
007 * modification, are permitted provided that the following conditions are
008 * met:
009 * 
010 *     * Redistributions of source code must retain the above copyright
011 *       notice, this list of conditions and the following disclaimer.
012 * 
013 *     * Redistributions in binary form must reproduce the above
014 *       copyright notice, this list of conditions and the following
015 *       disclaimer in the documentation and/or other materials provided
016 *       with the distribution.
017 * 
018 *     * Neither the name of the Technische Universität Berlin nor the
019 *       names of its contributors may be used to endorse or promote
020 *       products derived from this software without specific prior
021 *       written permission.
022 * 
023 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
024 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
025 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
026 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
027 * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
028 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
029 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
030 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
031 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
032 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
033 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
034 */
035// --- END LICENSE BLOCK ---
036
037package org.jblas;
038
039/**
040 * <p>General functions which are geometric in nature.</p>
041 * 
042 * <p>For example, computing all pairwise squared distances between all columns of a matrix.</p>
043 */
044public class Geometry {
045        
046        /**
047         * <p>Compute the pairwise squared distances between all columns of the two
048         * matrices.</p>
049         * 
050         * <p>An efficient way to do this is to observe that <i>(x-y)^2 = x^2 - 2xy - y^2</i>
051         * and to then properly carry out the computation with matrices.</p>
052         */
053        public static DoubleMatrix pairwiseSquaredDistances(DoubleMatrix X, DoubleMatrix Y) {
054                if (X.rows != Y.rows)
055                        throw new IllegalArgumentException(
056                                        "Matrices must have same number of rows");
057        
058                DoubleMatrix XX = X.mul(X).columnSums();
059                DoubleMatrix YY = Y.mul(Y).columnSums();
060        
061                DoubleMatrix Z = X.transpose().mmul(Y);
062                Z.muli(-2.0); //Z.print();
063                Z.addiColumnVector(XX);
064                Z.addiRowVector(YY);
065        
066                return Z;
067        }
068
069        /** Center a vector (subtract mean from all elements (in-place). */
070        public static DoubleMatrix center(DoubleMatrix x) {
071                return x.subi(x.mean());
072        }
073        
074        /** Center the rows of a matrix (in-place). */
075        public static DoubleMatrix centerRows(DoubleMatrix x) {
076                DoubleMatrix temp = new DoubleMatrix(x.columns);
077                for (int r = 0; r < x.rows; r++)
078                        x.putRow(r, center(x.getRow(r, temp)));
079                return x;
080        }
081        
082        /** Center the columns of a matrix (in-place). */
083        public static DoubleMatrix centerColumns(DoubleMatrix x) {
084                DoubleMatrix temp = new DoubleMatrix(x.rows);
085                for (int c = 0; c < x.columns; c++)
086                        x.putColumn(c, center(x.getColumn(c, temp)));
087                return x;
088        }
089        
090        /** Normalize a vector (scale such that its Euclidean norm is 1) (in-place). */
091        public static DoubleMatrix normalize(DoubleMatrix x) {
092                return x.divi(x.norm2());
093        }
094
095        /** Normalize the rows of a matrix (in-place). */
096        public static DoubleMatrix normalizeRows(DoubleMatrix x) {
097                DoubleMatrix temp = new DoubleMatrix(x.columns);
098                for (int r = 0; r < x.rows; r++)
099                        x.putRow(r, normalize(x.getRow(r, temp)));
100                return x;
101        }
102        
103        /** Normalize the columns of a matrix (in-place). */
104        public static DoubleMatrix normalizeColumns(DoubleMatrix x) {
105                DoubleMatrix temp = new DoubleMatrix(x.rows);
106                for (int c = 0; c < x.columns; c++)
107                        x.putColumn(c, normalize(x.getColumn(c, temp)));
108                return x;
109        }
110
111//BEGIN
112  // The code below has been automatically generated.
113  // DO NOT EDIT!
114        
115        /**
116         * <p>Compute the pairwise squared distances between all columns of the two
117         * matrices.</p>
118         * 
119         * <p>An efficient way to do this is to observe that <i>(x-y)^2 = x^2 - 2xy - y^2</i>
120         * and to then properly carry out the computation with matrices.</p>
121         */
122        public static FloatMatrix pairwiseSquaredDistances(FloatMatrix X, FloatMatrix Y) {
123                if (X.rows != Y.rows)
124                        throw new IllegalArgumentException(
125                                        "Matrices must have same number of rows");
126        
127                FloatMatrix XX = X.mul(X).columnSums();
128                FloatMatrix YY = Y.mul(Y).columnSums();
129        
130                FloatMatrix Z = X.transpose().mmul(Y);
131                Z.muli(-2.0f); //Z.print();
132                Z.addiColumnVector(XX);
133                Z.addiRowVector(YY);
134        
135                return Z;
136        }
137
138        /** Center a vector (subtract mean from all elements (in-place). */
139        public static FloatMatrix center(FloatMatrix x) {
140                return x.subi(x.mean());
141        }
142        
143        /** Center the rows of a matrix (in-place). */
144        public static FloatMatrix centerRows(FloatMatrix x) {
145                FloatMatrix temp = new FloatMatrix(x.columns);
146                for (int r = 0; r < x.rows; r++)
147                        x.putRow(r, center(x.getRow(r, temp)));
148                return x;
149        }
150        
151        /** Center the columns of a matrix (in-place). */
152        public static FloatMatrix centerColumns(FloatMatrix x) {
153                FloatMatrix temp = new FloatMatrix(x.rows);
154                for (int c = 0; c < x.columns; c++)
155                        x.putColumn(c, center(x.getColumn(c, temp)));
156                return x;
157        }
158        
159        /** Normalize a vector (scale such that its Euclidean norm is 1) (in-place). */
160        public static FloatMatrix normalize(FloatMatrix x) {
161                return x.divi(x.norm2());
162        }
163
164        /** Normalize the rows of a matrix (in-place). */
165        public static FloatMatrix normalizeRows(FloatMatrix x) {
166                FloatMatrix temp = new FloatMatrix(x.columns);
167                for (int r = 0; r < x.rows; r++)
168                        x.putRow(r, normalize(x.getRow(r, temp)));
169                return x;
170        }
171        
172        /** Normalize the columns of a matrix (in-place). */
173        public static FloatMatrix normalizeColumns(FloatMatrix x) {
174                FloatMatrix temp = new FloatMatrix(x.rows);
175                for (int c = 0; c < x.columns; c++)
176                        x.putColumn(c, normalize(x.getColumn(c, temp)));
177                return x;
178        }
179
180//END
181}