I have written earlier about faceted searching, where each facet a document exposes corresponds to a tag associated with that document. Of course, one of the most difficult aspects of setting up such a system is setting up the tags themselves. One way to build up the tag associations is to delegate the work to the creator of the document, an approach taken by sites with user-generated content. Often, however, we have a large number of untagged documents that we want to present as a searchable faceted collection. In such cases, we have to assign the tags ourselves, which can be quite labor-intensive if we do it manually.
One popular way to automatically tag documents is to use the Naive Bayes Classifier (NBC) algorithm. You can read about the math in the link, but basically NBC rests on the observation that if we know the probabilities of words appearing in documents of a certain category, then given the set of words in a new document, we can estimate how likely it is that the new document belongs to that category.
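As a rough sketch of the underlying math (my own notation, not taken from classifier4j): for a category $C$ and a new document containing the words $w_1, \ldots, w_n$, the classifier assumes the words are independent given the category and scores the document with Bayes' rule,

$$P(C \mid w_1, \ldots, w_n) \;\propto\; P(C) \prod_{i=1}^{n} P(w_i \mid C)$$

where each $P(w_i \mid C)$ is estimated from how often $w_i$ appeared in the training documents that matched the category versus those that did not.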
I first heard of NBC from an ex-colleague who suggested the automatic tagging idea, extending what he understood about how the SpamAssassin email spam filter works. Shortly thereafter, I read about it in Dr. Dobb's Journal. But I never had the opportunity to actually use it until now.
I figured that since NBC seemed to be a useful algorithm, there would be open source implementations available on the web. I found quite a few here. I looked through several, but the only one I saw with halfway decent user documentation was Classifier4J, so I chose it for my implementation of the automated tagger.
For my test data, I chose a collection of 21 articles I had written for my website years ago and manually categorized into "Databases", "Web Development" and "Linux". The plan was to train a Bayesian Classifier instance with one match document from the target category and two non-match documents from the two other categories, then run it against all 21 documents. My initial implementation used the components provided in the classifier4j distribution - SimpleWordsDataSource for the words data source, SimpleHTMLTokenizer for the tokenizer, and DefaultStopWordsProvider for the stop words provider.
However, the classification results were quite poor, and I wanted to find out why. I tried to build the package from source, but the project uses Maven 1.x, which I am not familiar with, and I ended up building an empty jar file. I then tried to inspect the words and their probabilities with the Eclipse debugger, but that did not give me any additional insight. So even though I try to avoid recreating functionality as much as possible, I ended up replacing most of the user-level components, depending only on classifier4j's core classes to do the probability calculations.
Usage
For convenience, I created the AutoTagger class, which is called from client code as follows:
AutoTagger autoTagger = new AutoTagger();
autoTagger.setStopwordFile(new File("/path/to/my/stopwords.txt"));
autoTagger.setDataSource(new DriverManagerDataSource("com.mysql.jdbc.Driver",
    "jdbc:mysql://localhost:3306/classifierdb", "user", "pass"));
autoTagger.addTrainingFiles("database", databaseFilesArray);
autoTagger.addTrainingFiles("web", webFilesArray);
autoTagger.addTrainingFiles("linux", linuxFilesArray);
autoTagger.train();
double p = autoTagger.getProbabilityOfFileInCategory("database", someDbFile);
The AutoTagger internally holds a Map of Classifier objects keyed by category. The train() call teaches each Classifier the matching words for its own category as well as the non-matching words from all the other categories. The Bayesian classifier tends to produce probabilities that are either 0.01, indicating no match, or 0.99, indicating a match.
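To turn these probabilities into tags, client code can simply ask about each category in turn. Here is a minimal sketch of how the tag assignment could look, assuming the same three categories as above and using the isFileInCategory() convenience method from the class below:

List<String> tags = new ArrayList<String>();
for (String category : new String[] {"database", "web", "linux"}) {
    // isFileInCategory() compares the probability against the 0.5 cutoff,
    // so a 0.99 result tags the file and a 0.01 result does not
    if (autoTagger.isFileInCategory(category, someFile)) {
        tags.add(category);
    }
}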
The source for the AutoTagger class is shown below.
public class AutoTagger {

    private static final double CLASSIFICATION_CUTOFF_PROBABILITY = 0.5;

    private File stopwordFile;
    private DataSource dataSource;
    private Map<String,BatchingBayesianClassifier> classifiers =
        new HashMap<String,BatchingBayesianClassifier>();
    private MultiMap categoryMap = new MultiHashMap();

    public AutoTagger() {
        super();
    }

    public void setStopwordFile(File stopwordFile) {
        this.stopwordFile = stopwordFile;
    }

    public void setDataSource(DataSource dataSource) {
        this.dataSource = dataSource;
    }

    public void addTrainingFiles(String category, File[] trainingFiles) {
        for (File trainingFile : trainingFiles) {
            categoryMap.put(category, trainingFile);
        }
        // if an instance of the classifier does not exist for category, create one
        if (! classifiers.containsKey(category)) {
            BatchingBayesianClassifier classifier = new BatchingBayesianClassifier(
                new JdbcWordsDataSource(dataSource),
                new CyberNekoHtmlTokenizer(DefaultTokenizer.BREAK_ON_WORD_BREAKS),
                new FileDrivenStopWordProvider(stopwordFile));
            classifiers.put(category, classifier);
        }
    }

    @SuppressWarnings("unchecked")
    public void train() throws WordsDataSourceException, ClassifierException, IOException {
        List<String> categoryList = new ArrayList<String>();
        categoryList.addAll(categoryMap.keySet());
        // teach the classifiers in all categories
        for (int i = 0; i < categoryList.size(); i++) {
            String matchCategory = categoryList.get(i);
            List<String> nonmatchCategories = new ArrayList<String>();
            for (int j = 0; j < categoryList.size(); j++) {
                if (i != j) {
                    nonmatchCategories.add(categoryList.get(j));
                }
            }
            BatchingBayesianClassifier classifier = classifiers.get(matchCategory);
            List<File> teachMatchFiles = (List<File>) categoryMap.get(matchCategory);
            for (File teachMatchFile : teachMatchFiles) {
                classifier.teachMatch(matchCategory,
                    FileUtils.readFileToString(teachMatchFile, "UTF-8"));
                classifiers.put(matchCategory, classifier);
                for (String nonmatchCategory : nonmatchCategories) {
                    classifier.teachNonMatch(nonmatchCategory,
                        FileUtils.readFileToString(teachMatchFile, "UTF-8"));
                    classifiers.put(nonmatchCategory, classifier);
                }
            }
        }
        classifiers.clear();
    }

    public boolean isFileInCategory(String category, File file)
            throws ClassifierException, WordsDataSourceException, IOException {
        return getProbabilityOfFileInCategory(category, file) >= CLASSIFICATION_CUTOFF_PROBABILITY;
    }

    public double getProbabilityOfFileInCategory(String category, File file)
            throws ClassifierException, WordsDataSourceException, IOException {
        if (! classifiers.containsKey(category)) {
            BatchingBayesianClassifier classifier = new BatchingBayesianClassifier(
                new JdbcWordsDataSource(dataSource),
                new CyberNekoHtmlTokenizer(DefaultTokenizer.BREAK_ON_WORD_BREAKS),
                new FileDrivenStopWordProvider(stopwordFile));
            classifiers.put(category, classifier);
        }
        BatchingBayesianClassifier classifier = classifiers.get(category);
        if (classifier == null) {
            throw new IllegalArgumentException("Unknown category:" + category);
        }
        return classifier.classify(category, FileUtils.readFileToString(file, "UTF-8"));
    }
}
JdbcWordsDataSource
To be able to view (for debugging) the words that were being considered during classification, I needed to put them in a database. However, the provided JDBCWordsDataSource is very slow, because it does an insert or update for every non-stop word in the input document. I created a similar implementation, JdbcWordsDataSource, which accumulates the inserts and updates until the entire document has been read, then applies them all at once. It does the same thing during classification, batching up all the words and issuing a single select call to get back all the word probability data. This produces a much more tolerable response time for the train() call (which in my case is actually three calls per document: one teachMatch() and two teachNonMatch() calls), and an almost instantaneous response for the classify() call. The code for my JdbcWordsDataSource is shown below:
/**
 * A Jdbc based implementation of ICategorisedWordsDataSource that can be
 * independently trained using files.
 */
public class JdbcWordsDataSource implements ICategorisedWordsDataSource {

    private JdbcTemplate jdbcTemplate;
    private Map<String,Integer> wordCountMap = new HashMap<String,Integer>();

    private Transformer quotingLowercasingTransformer = new Transformer() {
        public Object transform(Object input) {
            return "'" + StringUtils.lowerCase((String) input) + "'";
        }
    };

    public JdbcWordsDataSource(DataSource dataSource) {
        this.jdbcTemplate = new JdbcTemplate(dataSource);
    }

    public void addMatch(String word) throws WordsDataSourceException {
        addMatch(ICategorisedClassifier.DEFAULT_CATEGORY, word);
    }

    public void addMatch(String category, String word) throws WordsDataSourceException {
        addWord(word);
    }

    public void addNonMatch(String word) throws WordsDataSourceException {
        addNonMatch(ICategorisedClassifier.DEFAULT_CATEGORY, word);
    }

    public void addNonMatch(String category, String word) throws WordsDataSourceException {
        addWord(word);
    }

    public WordProbability getWordProbability(String word) throws WordsDataSourceException {
        return getWordProbability(ICategorisedClassifier.DEFAULT_CATEGORY, word);
    }

    @SuppressWarnings("unchecked")
    public WordProbability getWordProbability(String category, String word)
            throws WordsDataSourceException {
        int matchCount = 0;
        int nonmatchCount = 0;
        List<Map<String,Integer>> rows = jdbcTemplate.queryForList(
            "select match_count, nonmatch_count " +
            "from word_probability " +
            "where word = ? and category = ?",
            new String[] {word, category});
        for (Map<String,Integer> row : rows) {
            matchCount = row.get("MATCH_COUNT");
            nonmatchCount = row.get("NONMATCH_COUNT");
            break;
        }
        return new WordProbability(word, matchCount, nonmatchCount);
    }

    @SuppressWarnings("unchecked")
    public WordProbability[] calcWordsProbability(String category, String[] words) {
        List<WordProbability> wordProbabilities = new ArrayList<WordProbability>();
        List<String> wordsList = Arrays.asList(words);
        String query = "select word, match_count, nonmatch_count from word_probability where word in (" +
            StringUtils.join(new TransformIterator(wordsList.iterator(), quotingLowercasingTransformer), ',') +
            ") and category=?";
        List<Map<String,Object>> rows = jdbcTemplate.queryForList(query, new String[] {category});
        for (Map<String,Object> row : rows) {
            String word = (String) row.get("WORD");
            int matchCount = (Integer) row.get("MATCH_COUNT");
            int nonmatchCount = (Integer) row.get("NONMATCH_COUNT");
            WordProbability wordProbability = new WordProbability(word, matchCount, nonmatchCount);
            wordProbability.setCategory(category);
            wordProbabilities.add(wordProbability);
        }
        return wordProbabilities.toArray(new WordProbability[0]);
    }

    public void initWordCountMap() {
        wordCountMap.clear();
    }

    public void flushWordCountMap(String category, boolean isMatch) {
        for (String word : wordCountMap.keySet()) {
            int count = wordCountMap.get(word);
            if (isWordInCategory(category, word)) {
                updateWordMatch(category, word, count, isMatch);
            } else {
                insertWordMatch(category, word, count, isMatch);
            }
        }
    }

    @SuppressWarnings("unchecked")
    public void removeDuplicateWords() {
        List<Map<String,Object>> rows = jdbcTemplate.queryForList(
            "select word, count(*) dup_count " +
            "from word_probability " +
            "group by word " +
            "having dup_count > 1");
        List<String> words = new ArrayList<String>();
        for (Map<String,Object> row : rows) {
            words.add((String) row.get("WORD"));
        }
        jdbcTemplate.update("delete from word_probability where word in (" +
            StringUtils.join(new TransformIterator(words.iterator(), quotingLowercasingTransformer), ',') +
            ")");
    }

    private void addWord(String word) {
        int originalCount = 0;
        if (wordCountMap.containsKey(word)) {
            originalCount = wordCountMap.get(word);
        }
        wordCountMap.put(word, (originalCount + 1));
    }

    /**
     * Return true if the word is found in the category.
     * @param category the category to look up
     * @param word the word to look up.
     * @return true or false
     */
    @SuppressWarnings("unchecked")
    private boolean isWordInCategory(String category, String word) {
        List<Map<String,String>> rows = jdbcTemplate.queryForList(
            "select word from word_probability where category = ? and word = ?",
            new String[] {category, word});
        return (rows.size() > 0);
    }

    /**
     * @param category the category to update.
     * @param word the word to update.
     * @param isMatch if true, the word is a match for the category.
     */
    private void updateWordMatch(String category, String word, int count, boolean isMatch) {
        if (isMatch) {
            jdbcTemplate.update(
                "update word_probability set match_count = match_count + ? " +
                "where category = ? and word = ?",
                new Object[] {count, category, word});
        } else {
            jdbcTemplate.update(
                "update word_probability set nonmatch_count = nonmatch_count + ? " +
                "where category = ? and word = ?",
                new Object[] {count, category, word});
        }
    }

    /**
     * @param category the category to insert.
     * @param word the word to update.
     * @param isMatch if true, the word is a match for the category.
     */
    private void insertWordMatch(String category, String word, int count, boolean isMatch) {
        if (isMatch) {
            jdbcTemplate.update("insert into word_probability(" +
                "category, word, match_count, nonmatch_count) values (?, ?, ?, 0)",
                new Object[] {category, word, count});
        } else {
            jdbcTemplate.update("insert into word_probability(" +
                "category, word, match_count, nonmatch_count) values (?, ?, 0, ?)",
                new Object[] {category, word, count});
        }
    }
}
The JdbcWordsDataSource decouples word accumulation and persistence into two separate methods, which need to be called by the classifier. The accumulation is all done in memory, and a flushWordCountMap() call actually persists the map to the database.
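For reference, this is the shape of the word_probability table that the SQL above assumes. The column names come straight from the queries; the column sizes and the primary key are my own guesses, expressed here as a JdbcTemplate call:

// minimal sketch of the table assumed by JdbcWordsDataSource; the sizes and
// primary key are my assumptions, only the column names come from the queries
jdbcTemplate.execute(
    "create table word_probability (" +
    "  category varchar(32) not null, " +
    "  word varchar(100) not null, " +
    "  match_count int not null, " +
    "  nonmatch_count int not null, " +
    "  primary key (category, word))");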
BatchingBayesianClassifier
In order to use the batching capability, I needed to create a subclass of BayesianClassifier that accepts only these particular implementations and overrides the database-dependent methods in the parent. The BatchingBayesianClassifier is shown below:
/**
 * Batches words for performance against the JdbcWordsDataSource. This is
 * specific to this application's needs, so the constructor forces the caller
 * to provide specific implementations of the super-class's ctor args.
 */
public class BatchingBayesianClassifier extends BayesianClassifier {

    public BatchingBayesianClassifier(JdbcWordsDataSource wordsDataSource,
            CyberNekoHtmlTokenizer tokenizer, FileDrivenStopWordProvider stopwordsProvider) {
        super(wordsDataSource, tokenizer, stopwordsProvider);
    }

    protected boolean isMatch(String category, String input[]) throws WordsDataSourceException {
        return (super.classify(category, input) > super.getMatchCutoff());
    }

    protected double classify(String category, String words[]) throws WordsDataSourceException {
        List<String> nonStopwords = new ArrayList<String>();
        FileDrivenStopWordProvider stopwordsProvider =
            (FileDrivenStopWordProvider) super.getStopWordProvider();
        for (String word : words) {
            if (stopwordsProvider.isStopWord(word)) {
                continue;
            }
            nonStopwords.add(word);
        }
        JdbcWordsDataSource wds = (JdbcWordsDataSource) super.getWordsDataSource();
        WordProbability[] wordProbabilities =
            wds.calcWordsProbability(category, nonStopwords.toArray(new String[0]));
        return super.normaliseSignificance(super.calculateOverallProbability(wordProbabilities));
    }

    protected void teachMatch(String category, String words[]) throws WordsDataSourceException {
        JdbcWordsDataSource wds = (JdbcWordsDataSource) super.getWordsDataSource();
        wds.initWordCountMap();
        super.teachMatch(category, words);
        wds.flushWordCountMap(category, true);
    }

    protected void teachNonMatch(String category, String words[]) throws WordsDataSourceException {
        JdbcWordsDataSource wds = (JdbcWordsDataSource) super.getWordsDataSource();
        wds.initWordCountMap();
        super.teachNonMatch(category, words);
        wds.flushWordCountMap(category, false);
    }
}
CyberNekoHtmlTokenizer
I also created my own implementation of the HTML tokenizer, using the NekoHTML parser from CyberNeko. This was because the SimpleHtmlTokenizer was crashing on the (admittedly bad and nowhere near spec-compliant) HTML in my documents. CyberNeko's NekoHTML parser is more forgiving, and I was able to pull out the body of an HTML document with the following implementation:
public class CyberNekoHtmlTokenizer extends DefaultTokenizer {

    public CyberNekoHtmlTokenizer() {
        super();
    }

    public CyberNekoHtmlTokenizer(int tokenizerConfig) {
        super(tokenizerConfig);
    }

    /**
     * Uses the Cyberneko HTML parser to parse out the body content from the
     * HTML file as a stream of text.
     * @see net.sf.classifier4J.ITokenizer#tokenize(java.lang.String)
     */
    public String[] tokenize(String input) {
        return super.tokenize(getBody(input));
    }

    public String getBody(String input) {
        try {
            DOMParser parser = new DOMParser();
            parser.parse(new InputSource(new ByteArrayInputStream(input.getBytes())));
            Document doc = parser.getDocument();
            NodeList bodyTags = doc.getElementsByTagName("BODY");
            if (bodyTags.getLength() == 0) {
                throw new Exception("No body tag in this HTML document");
            }
            Node bodyTag = bodyTags.item(0);
            return bodyTag.getTextContent();
        } catch (Exception e) {
            throw new RuntimeException("HTML Parsing failed on this document", e);
        }
    }
}
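Used on its own, the tokenizer is a drop-in replacement for the default one. A quick sketch of calling it directly, where the html variable is assumed to hold the raw contents of one of my HTML articles:

CyberNekoHtmlTokenizer tokenizer =
    new CyberNekoHtmlTokenizer(DefaultTokenizer.BREAK_ON_WORD_BREAKS);
// getBody() strips everything outside the BODY tag, then tokenize() splits on word breaks
String[] tokens = tokenizer.tokenize(html);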
FileDrivenStopWordProvider
The DefaultStopWordProvider contains a hard-coded array of stop words, which is pretty basic, so I built one that works off a file (the contents of which I scraped from the classifier4j message board, btw) and which also treats numbers as stop words. The code for the FileDrivenStopWordProvider is shown below:
public class FileDrivenStopWordProvider implements IStopWordProvider {

    // logger for reporting stop word file problems (log4j assumed; the original
    // snippet referenced LOGGER without declaring it)
    private static final Logger LOGGER = Logger.getLogger(FileDrivenStopWordProvider.class);

    private SortedSet<String> words = new TreeSet<String>();

    public FileDrivenStopWordProvider(File stopWordFile) {
        try {
            BufferedReader reader = new BufferedReader(
                new InputStreamReader(new FileInputStream(stopWordFile)));
            String word;
            while ((word = reader.readLine()) != null) {
                words.add(StringUtils.lowerCase(word.trim()));
            }
            reader.close();
        } catch (FileNotFoundException e) {
            LOGGER.error("File:" + stopWordFile.getName() + " does not exist", e);
        } catch (IOException e) {
            LOGGER.error("Error reading file:" + stopWordFile.getName(), e);
        }
    }

    public boolean isStopWord(String word) {
        return words.contains(StringUtils.lowerCase(word.trim())) || StringUtils.isNumeric(word);
    }
}
Results
I ran the AutoTagger in two scenarios. The first was a low-training scenario: I took the first file that was created in each category, trained the classifiers with those, then ran the rest of the files against the trained classifiers. The assumption was that I knew what I was doing when I classified the first article in each category, rather than attempting to shoehorn an article into an existing category set. The results from the run are shown below. The rows in gray indicate the files which were used for training.
File name | Orig. class | P(database) | P(web) | P(linux) | Tags
--------- | ----------- | ----------- | ------ | -------- | ----
artdb001 | database | 0.99 | 0.01 | 0.01 | database
artdb002 | database | 0.99 | 0.01 | 0.01 | database
artdb003 | database | 0.01 | 0.01 | 0.01 | (none)
artdb005 | database | 0.01 | 0.01 | 0.01 | (none)
artdb006 | database | 0.01 | 0.01 | 0.01 | (none)
artdb007 | database | 0.01 | 0.01 | 0.01 | (none)
artwb001 | web | 0.01 | 0.99 | 0.01 | web
artwb002 | web | 0.01 | 0.01 | 0.01 | (none)
artwb003 | web | 0.01 | 0.01 | 0.01 | (none)
artwb004 | web | 0.01 | 0.01 | 0.01 | (none)
artwb005 | web | 0.01 | 0.01 | 0.01 | (none)
artwb006 | web | 0.01 | 0.01 | 0.01 | (none)
artwb007 | web | 0.01 | 0.01 | 0.01 | (none)
artli001 | linux | 0.01 | 0.01 | 0.01 | (none)
artli002 | linux | 0.01 | 0.01 | 0.01 | (none)
artli003 | linux | 0.01 | 0.01 | 0.01 | (none)
artli004 | linux | 0.01 | 0.01 | 0.01 | (none)
artli005 | linux | 0.01 | 0.01 | 0.01 | (none)
artli006 | linux | 0.01 | 0.01 | 0.99 | linux
artli007 | linux | 0.01 | 0.01 | 0.01 | (none)
artli008 | linux | 0.01 | 0.01 | 0.01 | (none)
As you can see, the results are not too great. Almost none of the documents besides the ones used for training were matched. This could be because of the paucity of training data. To rectify the situation, I set up a high-training scenario, where all but one of the files in each category are used for training, and the trained classifiers are then let loose on the one remaining file to see what category it falls into. The results for this test are shown below:
File name | Orig. class | P(database) | P(web) | P(linux) | Tags
--------- | ----------- | ----------- | ------ | -------- | ----
artdb001 | database | 0.99 | 0.01 | 0.01 | database
artdb002 | database | 0.99 | 0.01 | 0.01 | database
artdb003 | database | 0.99 | 0.01 | 0.01 | database
artdb005 | database | 0.01 | 0.01 | 0.01 | (none)
artdb006 | database | 0.99 | 0.99 | 0.01 | database, web
artdb007 | database | 0.01 | 0.01 | 0.01 | (none)
artwb001 | web | 0.01 | 0.99 | 0.01 | web
artwb002 | web | 0.01 | 0.99 | 0.01 | web
artwb003 | web | 0.01 | 0.01 | 0.01 | (none)
artwb004 | web | 0.01 | 0.01 | 0.01 | (none)
artwb005 | web | 0.01 | 0.99 | 0.01 | web
artwb006 | web | 0.01 | 0.99 | 0.01 | web
artwb007 | web | 0.99 | 0.99 | 0.01 | database, web
artli001 | linux | 0.01 | 0.01 | 0.01 | (none)
artli002 | linux | 0.01 | 0.01 | 0.99 | linux
artli003 | linux | 0.01 | 0.01 | 0.01 | (none)
artli004 | linux | 0.99 | 0.99 | 0.99 | database, web, linux
artli005 | linux | 0.01 | 0.01 | 0.99 | linux
artli006 | linux | 0.01 | 0.01 | 0.99 | linux
artli007 | linux | 0.01 | 0.01 | 0.99 | linux
artli008 | linux | 0.01 | 0.01 | 0.99 | linux
The results are better than in the first run, but it still misses a few. A surprising finding is that some articles are found to belong to multiple categories. It is not so surprising when you consider that the same person wrote all three types of article, so a web article could involve a database, and a linux article could describe a database or web server installation.
Conclusion
The NBC algorithm probably works best when there is much more training data available than I provided, and when the categories are more clearly separated, for example politics versus technology, so there is less chance of overlapping words between categories. In my case it does detect some matches, but the results could probably be improved by providing more training data, or by pruning the words in the database after training is complete and before classification is done.
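The pruning could be as simple as calling the removeDuplicateWords() method already defined on JdbcWordsDataSource, which deletes words that were learned under more than one category. A rough sketch, where wordsDataSource is assumed to be the JdbcWordsDataSource instance used during training:

// after train() and before classify(): drop words that occur in more than one
// category, so that only category-specific words influence the probabilities
wordsDataSource.removeDuplicateWords();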