Showing posts with label Java. Show all posts

Saturday, September 29, 2012

Mallet and LibSVM

Mallet and LibSVM are the two machine learning libraries I have been using the most, and I felt the need for a way to use LibSVM directly from Mallet. As I mentioned in another post, I made a lightly refactored version of LibSVM's Java implementation, mainly to make it easy to plug in custom kernel functions. Doing that gave me a better understanding of how LibSVM works, which in turn helped me integrate it with Mallet.
For classification tasks, a Mallet pipe turns each instance into a FeatureVector, so it is quite straightforward to transform it into a format suitable for LibSVM. Custom kernel functions that work on data structures other than vectors, however, need to be handled differently. The current version has no option for passing an arbitrary data structure from the Mallet side, but the code can easily be tweaked to allow that.
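To make the FeatureVector-to-LibSVM step concrete, here is a minimal sketch, not the code in SVMClassifier. It assumes the feature vector is available as parallel index/value arrays, which, as far as I know, Mallet's FeatureVector exposes through numLocations(), indexAtLocation(i) and valueAtLocation(i). Node is a hypothetical stand-in for the stock LibSVM Java port's svm_node (an int index and a double value). Nodes are sorted by ascending index because LibSVM's sparse dot product assumes that order; the +1 shift to 1-based indices is an illustrative choice that follows LibSVM's data file convention.

```java
import java.util.Arrays;
import java.util.Comparator;

public class FeatureVectorToSvmNodes {

    /** Hypothetical stand-in for libsvm.svm_node. */
    static final class Node {
        final int index;
        final double value;

        Node(int index, double value) {
            this.index = index;
            this.value = value;
        }
    }

    /**
     * Converts parallel index/value arrays (0-based feature indices) into
     * nodes with 1-based indices, drops zero values, and sorts by index.
     */
    static Node[] toNodes(int[] indices, double[] values) {
        if (indices.length != values.length) {
            throw new IllegalArgumentException("indices and values differ in length");
        }
        Node[] nodes = new Node[indices.length];
        int n = 0;
        for (int i = 0; i < indices.length; i++) {
            if (values[i] != 0.0) {
                nodes[n++] = new Node(indices[i] + 1, values[i]);
            }
        }
        Node[] result = Arrays.copyOf(nodes, n);
        // LibSVM walks two node arrays in lockstep, so indices must ascend.
        Arrays.sort(result, Comparator.comparingInt(node -> node.index));
        return result;
    }

    public static void main(String[] args) {
        // Features 7, 2 and 5 (0-based, unsorted); feature 5 has value zero.
        Node[] nodes = toNodes(new int[] {7, 2, 5}, new double[] {1.5, 0.5, 0.0});
        for (Node node : nodes) {
            System.out.println(node.index + ":" + node.value); // 3:0.5 then 8:1.5
        }
    }
}
```

A custom kernel over non-vector data would skip this conversion entirely and attach its own object to each node instead.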
Being separate libraries, Mallet and LibSVM handle class labels differently. All I had to do in SVMClassifier was align the class labels and scores of the two libraries. I have also kept an option that tells LibSVM whether or not to predict probabilities; you need it if you want not only the best class but also the scores given to the other classes.
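The label alignment can be sketched as follows, under a loudly stated assumption: that each training instance's Mallet label index (0 to numLabels - 1) is what gets passed to LibSVM as its class value. In the stock LibSVM Java API, svm.svm_predict_probability fills its estimates in the order of the model's labels (retrievable with svm.svm_get_labels), which is the order in which LibSVM first saw each class, not Mallet's order. LabelAligner and its methods are hypothetical names for illustration only.

```java
import java.util.Arrays;

public class LabelAligner {

    /**
     * Reorders LibSVM probability estimates into Mallet label-index order.
     * Assumes libsvmLabels holds Mallet label indices in LibSVM's internal order.
     */
    static double[] alignScores(int[] libsvmLabels, double[] probEstimates, int numMalletLabels) {
        if (libsvmLabels.length != probEstimates.length) {
            throw new IllegalArgumentException("label and estimate arrays differ in length");
        }
        double[] scores = new double[numMalletLabels]; // classes unseen in training stay 0.0
        for (int i = 0; i < libsvmLabels.length; i++) {
            scores[libsvmLabels[i]] = probEstimates[i];
        }
        return scores;
    }

    /** Returns the Mallet label index with the highest score. */
    static int bestLabel(double[] scores) {
        int best = 0;
        for (int k = 1; k < scores.length; k++) {
            if (scores[k] > scores[best]) {
                best = k;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Suppose LibSVM saw class 2 first, then 0, then 1.
        int[] libsvmLabels = {2, 0, 1};
        double[] probs = {0.6, 0.3, 0.1};
        double[] scores = alignScores(libsvmLabels, probs, 3);
        System.out.println(Arrays.toString(scores)); // [0.3, 0.1, 0.6]
        System.out.println(bestLabel(scores));        // 2
    }
}
```

Without probability estimates, LibSVM only returns the predicted class value, so there are no per-class scores to align.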
If you are interested, get it from GitHub. Let me know if you have any suggestions.

Wednesday, September 12, 2012

Writing Custom Kernel Functions in Java for LibSVM

For my research on protein-protein interaction extraction, I had to experiment with several different custom kernel functions. For that I looked into the two most prevalent support vector machine libraries: SVMLight and LibSVM. In SVMLight, one can plug in a custom kernel function through the kernel.h header file. LibSVM, on the other hand, does not allow custom kernel functions directly; however, one can precompute the kernel matrix (or Gram matrix) and feed it to the SVM as input. To me SVMLight seemed the way to go, but then I found that LibSVM comes with an official Java implementation. I looked for a library that modifies that Java port to allow direct integration of kernel functions and found jlibsvm, which might have worked if it had come with a little documentation. So I decided to write a lightly refactored LibSVM of my own. It did not take much effort, and I have been using it ever since. If you prefer to write your custom kernel functions in Java, give it a try:
https://github.com/syeedibnfaiz/libsvm-java-kernel.git 
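For comparison, this is roughly what the precomputed-kernel route in stock LibSVM (kernel type 4, "-t 4") involves. According to the LibSVM README, each training row is written as `<label> 0:i 1:K(xi,x1) ... L:K(xi,xL)`, where i is the 1-based serial number of the example. The sketch below produces such rows for dense vectors with any kernel passed in as a function; the class name and the dense-vector representation are illustrative assumptions.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.ToDoubleBiFunction;

public class PrecomputedKernelWriter {

    /** Builds one LibSVM precomputed-kernel line per training example. */
    static List<String> trainingLines(double[][] xs, int[] labels,
                                      ToDoubleBiFunction<double[], double[]> kernel) {
        List<String> lines = new ArrayList<>();
        for (int i = 0; i < xs.length; i++) {
            StringBuilder sb = new StringBuilder();
            sb.append(labels[i]).append(" 0:").append(i + 1); // serial number of x_i
            for (int j = 0; j < xs.length; j++) {
                // Column j+1 holds K(x_i, x_j); the whole Gram matrix is materialized.
                sb.append(' ').append(j + 1).append(':').append(kernel.applyAsDouble(xs[i], xs[j]));
            }
            lines.add(sb.toString());
        }
        return lines;
    }

    /** Linear kernel on dense vectors. */
    static double dot(double[] a, double[] b) {
        double s = 0.0;
        for (int k = 0; k < a.length; k++) {
            s += a[k] * b[k];
        }
        return s;
    }

    public static void main(String[] args) {
        double[][] xs = {{1.0, 0.0}, {1.0, 2.0}};
        int[] labels = {1, -1};
        for (String line : trainingLines(xs, labels, PrecomputedKernelWriter::dot)) {
            System.out.println(line);
        }
        // 1 0:1 1:1.0 2:1.0
        // -1 0:2 1:1.0 2:5.0
    }
}
```

With L training examples this costs L² kernel evaluations and an L × L file (test rows use `0:?` followed by kernel values against every training example), which is a large part of why plugging the kernel function straight into LibSVM is appealing.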

Writing a kernel function could not be easier: all you have to do is implement the CustomKernel interface. Here is how you can write a linear kernel:
 /**
  * <code>LinearKernel</code> implements a linear kernel function.
  * @author Syeed Ibn Faiz
  */
 public class LinearKernel implements CustomKernel {
   @Override
   public double evaluate(svm_node x, svm_node y) {
     // Each svm_node carries its payload in the data field; this kernel expects sparse vectors.
     if (!(x.data instanceof SparseVector) || !(y.data instanceof SparseVector)) {
       throw new RuntimeException("Could not find sparse vectors in svm_nodes");
     }
     SparseVector v1 = (SparseVector) x.data;
     SparseVector v2 = (SparseVector) y.data;
     // Linear kernel: K(x, y) is the dot product of x and y.
     return v1.dot(v2);
   }
 }
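Other kernels follow the same pattern. As a hedged illustration, here is an RBF kernel, K(x, y) = exp(-gamma * ||x - y||²), computed from sparse dot products via ||x - y||² = x·x + y·y - 2x·y. It is self-contained, so SparseVec and SparseKernel are hypothetical stand-ins for the library's SparseVector and CustomKernel; the merge-style dot product is one common way to implement a sparse dot, not necessarily what SparseVector.dot does.

```java
public class RbfKernelSketch {

    /** Minimal sparse vector; indices must be strictly ascending. */
    static final class SparseVec {
        final int[] indices;
        final double[] values;

        SparseVec(int[] indices, double[] values) {
            this.indices = indices;
            this.values = values;
        }

        /** Merge-style dot product: walk both sorted index lists once. */
        double dot(SparseVec other) {
            double sum = 0.0;
            int i = 0, j = 0;
            while (i < indices.length && j < other.indices.length) {
                if (indices[i] == other.indices[j]) {
                    sum += values[i++] * other.values[j++];
                } else if (indices[i] < other.indices[j]) {
                    i++;
                } else {
                    j++;
                }
            }
            return sum;
        }
    }

    /** Mirrors the shape of CustomKernel.evaluate, but on SparseVec directly. */
    interface SparseKernel {
        double evaluate(SparseVec x, SparseVec y);
    }

    static final class RbfKernel implements SparseKernel {
        private final double gamma;

        RbfKernel(double gamma) {
            this.gamma = gamma;
        }

        @Override
        public double evaluate(SparseVec x, SparseVec y) {
            double sqDist = x.dot(x) + y.dot(y) - 2.0 * x.dot(y);
            // Floating-point rounding can push sqDist slightly below zero; clamp it.
            return Math.exp(-gamma * Math.max(0.0, sqDist));
        }
    }

    public static void main(String[] args) {
        SparseVec a = new SparseVec(new int[] {1, 3}, new double[] {1.0, 2.0});
        SparseVec b = new SparseVec(new int[] {3, 4}, new double[] {2.0, 1.0});
        System.out.println(a.dot(b));                          // 4.0
        System.out.println(new RbfKernel(0.5).evaluate(a, b)); // about 0.3679, i.e. exp(-1)
    }
}
```

Expanding the squared distance avoids building a difference vector, and the self dot products x·x could be cached per example if kernel evaluations dominate training time.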

The kernel function you want to use should then be registered with the KernelManager. The following code snippet gives an idea of the whole workflow:
 public static void testLinearKernel(String[] args) throws IOException, ClassNotFoundException {
     String trainFileName = args[0];
     String testFileName = args[1];
     String outputFileName = args[2]; // not used further in this snippet

     // Read training file
     Instance[] trainingInstances = DataFileReader.readDataFile(trainFileName);

     // Register kernel function
     KernelManager.setCustomKernel(new LinearKernel());

     // Set up parameters
     svm_parameter param = new svm_parameter();

     // Train the model
     System.out.println("Training started...");
     svm_model model = SVMTrainer.train(trainingInstances, param);
     System.out.println("Training completed.");

     // Read test file
     Instance[] testingInstances = DataFileReader.readDataFile(testFileName);
     // Predict results
     double[] predictions = SVMPredictor.predict(testingInstances, model, true);
 }
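The snippet reads an output file name but stops once the predictions are computed. One possible follow-up, sketched under assumptions: the gold labels are passed in as a plain double[], because how to read them out of the library's Instance objects is not shown here, and PredictionReport is a hypothetical helper, not part of the library. It computes accuracy and writes one prediction per line using only the standard library.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;

public class PredictionReport {

    /** Fraction of predictions equal to the gold labels (class values are integral). */
    static double accuracy(double[] predictions, double[] goldLabels) {
        if (predictions.length != goldLabels.length) {
            throw new IllegalArgumentException("length mismatch");
        }
        if (predictions.length == 0) {
            return 0.0;
        }
        int correct = 0;
        for (int i = 0; i < predictions.length; i++) {
            if (predictions[i] == goldLabels[i]) {
                correct++;
            }
        }
        return (double) correct / predictions.length;
    }

    /** Writes one prediction per line as UTF-8 text. */
    static void writePredictions(double[] predictions, Path out) throws IOException {
        List<String> lines = new ArrayList<>();
        for (double p : predictions) {
            lines.add(Double.toString(p));
        }
        Files.write(out, lines);
    }

    public static void main(String[] args) throws IOException {
        double[] predictions = {1.0, -1.0, 1.0, 1.0};
        double[] gold = {1.0, -1.0, -1.0, 1.0};
        System.out.println("Accuracy: " + accuracy(predictions, gold)); // Accuracy: 0.75
        writePredictions(predictions, Paths.get(args.length > 0 ? args[0] : "predictions.txt"));
    }
}
```

Inside testLinearKernel this would amount to calling writePredictions(predictions, Paths.get(outputFileName)) after the prediction step.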