CodeIQ:アイテム類似度のレコメンド

私自身が表題の問題を解いた時のプログラムについて解説します。
問題の詳細は「アイテム類似度のレコメンド」(CodeIQ)を参照してください。

問題の概要

問題を引用します。
複数のユーザのアイテムの購入履歴が与えられた時に、他のユーザの履歴を利用して、ユーザに対しておすすめを表示したり、閲覧しているアイテムに対して関連するアイテムを表示する技術は多くのウェブサイトで採用され重要な技術となっている。
本問では最も基本的な手法である類似度に基づいたアイテム型のレコメンデーションを対象にする。
下はユーザとアイテムの評価行列の一例である。

アイテム1アイテム2アイテム3アイテム4アイテム5
ユーザ131233
ユーザ243435
ユーザ333154
ユーザ415521

類似度の指標については様々なものがあるが、ここでは類似度指標としてコサイン類似度を用いる。また二つのアイテム間の類似度を扱う。
アイテムaとアイテムbのコサイン類似度は以下で表現される(abはベクトル)。
Sim(a,b) = (ab)/(|a||b|)

例えば、アイテム1とアイテム5の類似度は以下で表現される。
(3*3+4*5+3*4+1*1)/{√(32+42+32+12) * √(32+52+42+12)}

与えられたアイテムに対して、最もコサイン類似度が高い(自身を除く)上位3件を出力するプログラムを書いてください。

【入力】
1行目:ユーザー数nUSER、アイテム数nITEM、レーティング数N
2〜(N+1)行目:ユーザー(1〜nUSER)、アイテム(1〜nITEM)、レーティング(1〜5)
 ※レーティングが存在しない場合、0としてください。
(N+2)行目:レコメンドを出力すべきアイテム数M
(N+3)〜(N+M+2)行目:レコメンドを出力すべきアイテム

【出力】
M行に渡り、「アイテム番号、類似度が最も高いアイテム番号、類似度が2番目に高いアイテム番号、類似度が3番目に高いアイテム番号」を空白区切りで出力してください。

【制約】
1<=nUSER<=500
1<=nITEM<=500
1<=N<=20000
1<=M<=35

私のプログラム

Javaで解答しています。

import java.io.*;
import java.util.*;


public class Main{
	static int[][] DataArray;	// 入力データの行列 (行:アイテム番号、列:ユーザ番号) 0始まり
	static int nItem;	// アイテム数
	static int nUser;	// ユーザ数
	static int nRating;	// レーティング数
	static int retItemNum;	// 類似度を出力するアイテム数
	static int[] retItems;	// 類似度を出力するアイテム番号 0始まり

	static double[] ItemVecLenArray;
	static double[][] SimCosArray;

	/**
	 *	データ数情報をセットする
	 *	@param line 入力行(ユーザ数 アイテム数 レーティング数)
	 */
	public static void parseDataNum(String line){
		String[] splited = line.split(" ");
		int[] ret = new int[splited.length];
		for(int i=0; i<splited.length; i++){
			ret[i] = Integer.parseInt(splited[i]);
		}

		nUser = ret[0];
		nItem = ret[1];
		nRating = ret[2];
	}

	/**
	 *	データの配列を作成する
	 *	行番号はアイテム番号(入力値は1始まり、実装は0始まり)
	 *	列番号はユーザ番号(入力値は1始まり、実装は0始まり)
	 *	※ 問題の表とは行列が入れ替わる
	 */
	public static void createDataArray(){
		DataArray = new int[nItem][];
		for(int i=0; i<nItem; i++){
			DataArray[i] = new int[nUser];
		}
	}

	/**
	 *	データ行列に値を設定する
	 *	@param line 入力値(ユーザ番号 アイテム番号 レーティング)
	 */
	public static void setDataArray(String line){
		String[] splited = line.split(" ");
		int[] ret = new int[splited.length];
		for(int i=0; i<splited.length; i++){
			ret[i] = Integer.parseInt(splited[i]);
		}

		int u = ret[0]-1;	// ユーザ番号
		int i = ret[1]-1;	// アイテム番号
		int r = ret[2];	// レコメンド

		DataArray[i][u] = r;
	}

	/**
	 *	データ行列を表示する(デバッグ)
	 */
	public static void printDataArray(){
		for(int[] l: DataArray){
			for(int i: l){
				System.out.print(i);
				System.out.print(" ");
			}
			System.out.println("");
		}
	}

	/**
	 *	結果を表示するアイテム番号のリストを作る
	 */
	public static void createRetItems(String line){
		retItemNum = Integer.parseInt(line);
		retItems = new int[retItemNum];
	}

	/**
	 *	結果を表示するアイテム番号のリストにアイテム番号をセットする
	 *	@param line 入力値(0始まりに直す)
	 *	@param n 何番目に入力された値か(0始まり)
	 */
	public static void setRetItems(String line, int n){
		int i = Integer.parseInt(line);
		retItems[n] = i-1;
	}

	/**
	 *	アイテム番号のリストを表示する(デバッグ)
	 */
	public static void printRetItems(){
		for(int i: retItems){
			System.out.print(i);
			System.out.print(" ");
		}
		System.out.println("");
	}

	/**
	 *	各アイテムの√Σ(レーティング^2)を計算して表にしておく
	 */
	public static void calcItemVecLenArray(){
		ItemVecLenArray = new double[nItem];
		for(int j=0; j<DataArray.length; j++){
			long s = 0;
			for(int i: DataArray[j]){
				s += i*i;
			}
			ItemVecLenArray[j] = Math.sqrt(s);
		}
	}

	/**
	 *	各アイテムの√Σ(レーティング^2)を表示する
	 */
	public static void printItemVecLenArray(){
		for(double d: ItemVecLenArray){
			System.out.println(d);
		}
	}

	/**
	 *	コサイン類似度の計算結果行列領域を作成する
	 */
	public static void createSimCosArray(){
		SimCosArray = new double[nItem][];
		for(int i=0; i<nItem; i++){
			SimCosArray[i] = new double[nItem];
		}
	}

	/**
	 *	コサイン類似度の計算結果行列領域を表示する(デバッグ)
	 */
	public static void printSimCosArray(){
		for(double[] l: SimCosArray){
			for(double d: l){
				System.out.printf("%4f ", d);
			}
			System.out.println("");
		}
	}

	/**
	 *	Σ(a*b)を計算する
	 */
	private static double calcAB(int[] a, int[] b){
		double s = 0;
		for(int i=0; i<nUser; i++){
			s += a[i] * b[i];
		}
		return s;
	}

	/**
	 *	コサイン類似度を全アイテムに対して計算し、結果を表に収める
	 */
	public static void simCos(){
		for(int i=0; i<nItem; i++){
			for(int j=i+1; j<nItem; j++){
				int[] a = DataArray[i];
				int[] b = DataArray[j];

				SimCosArray[i][j] = calcAB(a, b) / (ItemVecLenArray[i] * ItemVecLenArray[j]);
				SimCosArray[j][i] = SimCosArray[i][j];
			}
		}
	}

	/**
	 *	指定されたアイテムと類似度の高いもの3個の番号(1始まり)を表示する
	 *	@param index 対象のアイテム番号(0始まり)
	 */
	private static void printMax3(int index){
		double[] l = SimCosArray[index];
		double max[] = new double[3];
		int max_i[] = new int[3];

		for(int i=0; i<l.length; i++){
			if(max[0] < l[i]){
				max[2] = max[1];
				max_i[2] = max_i[1];
				max[1] = max[0];
				max_i[1] = max_i[0];
				max[0] = l[i];
				max_i[0] = i;
			}
			else if(max[1] < l[i]){
				max[2] = max[1];
				max_i[2] = max_i[1];
				max[1] = l[i];
				max_i[1] = i;
			}
			else if(max[2] < l[i]){
				max[2] = l[i];
				max_i[2] = i;
			}
		}

		String s = String.valueOf(index + 1) + " ";
		for(int i=0; i<3; i++){
			s += String.valueOf(max_i[i]+1) + " ";
		}
		s = s.trim();
		System.out.println(s);
	}

	/**
	 *	結果を表示する
	 */
	public static void printResult(){
		for(int i: retItems){
			printMax3(i);
		}
	}

	public static void main(String args[]) throws IOException{
		try(BufferedReader br = new BufferedReader(new InputStreamReader(System.in))){
			String line;

			for(int i=0; (line = br.readLine()) != null; i++){
				line = line.trim();

				if(i == 0){
					parseDataNum(line);
					createDataArray();
				}
				else if(i <= nRating){
					setDataArray(line);
				}
				else if(i == nRating+1){
					createRetItems(line);
				}
				else{
					setRetItems(line, i-(nRating+2));
				}
			}

			calcItemVecLenArray();
			createSimCosArray();
			simCos();
			printResult();
		}
	}
}

解説

★★★★の問題ですが私はあまり難しいとは思いませんでした。
問題の説明は詳しいですし、ロジックは明確に示されていてその通りに実装すれば良いと思います。

入力

入力が複雑なので面倒です。
parseDataNum()で1行目の情報を処理します。
createDataArray()は1行めで読み込んだユーザ数、アイテム数を元にアイテムごとのユーザによるレーティングを2次元配列に保持します。
setDataArray()でユーザごとのアイテムに対するレーティングをパースし、DataArrayに値をセットします。
createRetItems()はレコメンドを出力すべきアイテム数Mに従ってその領域(retItems)を確保します。
setRetItems()で結果を表示すべきアイテム番号をretItemsに記録します。

DataArray

入力されたユーザごとのアイテムへのレーティングを記録します。
問題の表をプログラムで表現したものですが、行と列を入れ替えています。

calcItemVecLenArray()

この問題で唯一工夫したと言える部分です。
アイテムaとアイテムbのコサイン類似度は以下で表現されます(abはベクトル)。
Sim(a,b) = (ab)/(|a||b|)

このうち|a|や|b|はあらかじめ1回だけ計算しておけば結果を使い回すことができます。なのでそれを計算してしまい、ItemVecLenArrayに保持します。
ItemVecLenArrayは要素番号(0始まり)ごとに各アイテムの計算結果を保持する配列です。

createSimCosArray()

このメソッドで全てのアイテムの組み合わせの計算結果を保持する領域(SimCosArray)を確保します。サイズはnItem×nItemのに次元配列です。

simCos()

全てのアイテムの組み合わせでコサイン類似度を計算します。
アイテムごとのレーティングはDataArrayの対応する要素番号(0始まり)にアクセスすれば取得できます(行をアイテムにしておいたので単に行を取り出せば済みます)。
分母の|a|や|b|は計算済みなのでItemVecLenArrayから対応する要素番号を取り出せばOKです。
さらに、アイテムiとアイテムjの類似度はアイテムjとアイテムiの類似度と同じなので計算量は半分にできます。ただし結果はアイテム番号でそのアイテムと他の全アイテムの類似度を全て取得したいので171行目の処理でiとjを入れ替えた場所にも同じ値を保持します。

結果出力

printResult()で入力で指定されたアイテム分ループし、その中でprintMax3()を呼んでそのアイテムの類似度上位3件を表示します。
上位3件なので面倒臭いですが、指定されたアイテムの計算結果を先頭からチェックし、大きなものを選んでいるだけです。

雑感

解答を書いて投稿する段では「これで時間切れになったら改善する方法を思いつかない(せいぜいItemVecLenArrayを入力を読みながら作るくらい)」と思って緊張しましたが、余裕でパスしてどこが難しかったんだろうというのが感想です。