CodeIQ:素数で作る天秤ばかり

私自身が表題の問題を解いた時のプログラムについて解説します。
問題の詳細は「素数で作る天秤ばかり」(CodeIQ)を参照してください。

問題の概要

問題を引用します。
天秤ばかりを使って重さを量りたいと考えています。
ただし、使えるおもりは重さが素数のものしかありませんでした。

m, n をともに正の整数とし、m 以下の素数すべてがおもりの重さとして1つずつ用意されているとき、n グラムの計り方が何通りあるかを求めてください。

例えば、m = 10, n = 2のとき、2, 3, 5, 7 のおもりが一つずつありますので、左右に以下のおもりを使った4通りがあります。(量るものを左側に置いたとします。)

左側右側
なし[2]
[3][5]
[5][7]
[2,3][7]

標準入力から m と n がスペースで区切って与えられるとき、n グラムの計り方が何通りあるかを標準出力に出力してください。
(ただし、 m < 50 とします。)

【入出力サンプル】
Input
10 2

Output
4

私のプログラム

Rubyで解答しています。

#!/usr/bin/ruby

Primes = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47]

def toBit(arr)
	ret = 0
	for a in arr
		ret |= 1 << a
	end
	return ret
end

def getSumsRight(primes, sum)
	ret = {}

	first = []
	for v in primes
		first << [v]
		ret[v] = [toBit([v])]
	end

	cycle = [first]

	for i in 1...primes.size
		nlist = []

		for c in cycle[i-1]
			mx = c.max
			for v in primes.select{|x| mx < x}
				nxt = c + [v]
				s = nxt.reduce(:+)

				if s <= sum then
					nlist << nxt
					if ret[s] == nil then ret[s] = [toBit(nxt)]
					else ret[s] << toBit(nxt)
					end
				end
			end
		end

		if nlist.empty? then break
		else cycle << nlist
		end
	end

	return ret
end

def solve(m,n)
	cnt = 0

	primes = Primes.select{|i| i <= m}
	right = (primes.reduce(:+) + n) / 2		# 右の最大の重さ

	cond = getSumsRight(primes, right)
#	p cond
#	p cond.each{|k, lst|
#		for l in lst
#		 	printf("%d => %s\n", k, l.to_s(2))
#		end
#	 }

	cond.each{|sum_r, cond_r|
		sum_l = sum_r - n

		if sum_l == 0 then
			cnt += cond_r.size
		else
			for cr in cond_r
				cond_l = cond[sum_l]

				if cond_l == nil then next end

				for cl in cond_l
					if (cl & cr) == 0 then
						cnt += 1
					end
				end
			end
		end
	}

	return cnt
end

# main
while line = gets
	line.strip!
	if line.empty? then next end

	m,n = line.split.map{|a| a.to_i}
	p solve(m,n)
end

解説

難しいです。何か頭の良いやり方があると思うのですがわかりませんでした。
私のコードはロジック的にはかなり素朴な方法になっています。

考え方

基本的な考え方は次の通りです。
右側(量るものの乗っていない方)の重さとして取り得る値を列挙します。
この列挙した値には左側(量るものの乗っている方)として取り得る値も含んでいます。
なので、右側の候補値に対しその合計-nとなる値を左側の値として選び、左右の素数に重複のない組み合わせを数えます。

例えば入力値(m,n)=(10,2)の場合、

合計素数リスト
2[2]
3[3]
5[2,3],[5]
7[2,5],[7]
8[3,5]
9[2,7]
を作ります。そして、
右側が9([2,7])の時は左側の7([2,5],[7])を選び、7の素数リストはいずれも重複があるので除外、
右側が8([3,5])の時は左側の6はないので除外、
右側が7([2,5])の時は左側の5([2,3],[5])を選びずれも重複があるので除外、右側が([7])の時は左側の5([2,3],[5])を選びずれも重複がないのでOK、
……
という風にやって行くということです。

この方法で一番の問題点は上限をどうするかです。50以下の素数の全ての組み合わせについて列挙してしまえば簡単ですが、あまりにも無駄が多すぎます。
天秤が釣り合うということは左右の重さが同じになるということです。
なので、使える重りがx個あるとしたら、右側の重さの最大値はx個の重り全部の合計にnを足した値の半分を超えることはありません。
なので、使用可能な素数の組み合わせでこの値以下になるものを列挙すれば良いことになります。

次の問題点は候補同士を一つづつ素数の重複がないかをチェックするためいかにも遅いということです。これはビット演算を使用することで回避しました。

toBit()

引数は素数のリストです。
このリストに含まれる値のビット位置を1にした整数を返します。

getSumsRight()

右側に乗る重りのパターンを列挙して返す関数です。
引数primesは使用可能な素数のリスト、sumは右側に乗る最大の重さです。
戻り値は{合計値 => [素数リスト1, 素数リスト2, …], ……}形式の連想配列です。

まず、初期値を作ります。これはprimesから1個ずつ値を取った場合になります(16〜22行目)。retは戻り値で1個だけ値を選んだ場合はそれ自体が戻り値の要素になるのでこの時点でretに追加します。この時、素数リストはtoBit()で整数の配列からビットフラグに変換します。

初期値に値を追加することで候補値を作ります(24〜45行目)。全ての値を1回ずつ使ったら絶対に終わりなので最大でも1〜primesの要素数だけループしません。0始まりでないのは初期値のために最初の1回を済ませているからです。
一つ前のループ結果を取り出してさらにその要素だけループします(27〜40行目)。
全ての使用可能な素数を1回前に作った要素に加えます(29〜39行目)。この時、すでに使用した値の中で最大のもの以下の値は無視することで重複なく列挙できます。
素数リストに値を追加した(30行目)ら合計値を求め(31行目)ます。
合計値が引数sumを超えたら無視し(33〜38行目)、sum以下なら次の周回の候補値(34行目)と結果に追加(35〜37行目)します。
もし、次の候補値のリストが空ならループを抜けます(42〜44行目)。

solve()

引数は入力値のm,nです。
cntは結果の組み合わせ数です。

mから使用可能な素数リストを作ります(53行目)。
右側の候補値を作成します(54行目)。

候補値は合計値ごとなのでその単位でループします(64行目)。
sum_rは候補値の合計値でsum_lはそれに対応する左側の重りの重さの合計です(65行目)。
もし、sum_lが0なら右側にだけ重りを置けば良いことになるので、cond_rの要素数をcntに加えます(67〜68行目)。
それ以外の場合は左側にも重りが必要なので候補を検索します(69〜81行目)。
cond_rは例の場合で右側の合計値が7の場合は[100100,1000000]のような配列なので、それぞれについて処理します(70〜80行目)。ちなみに100100は5ビット目と2ビット目が1なので[2,5]を使って7にしていることを表しています(ビット位置は0始まりなので最下位ビットは常に0になります)。
cond_lは左側に乗る候補値のリストです(71行目)。
その重さを素数リストで構成できなかった場合は無視します(73行目)。
cond_lもcond_rと同じく配列なので要素の数だけループします(75〜79行目)。
右側の要素(cr)と左側の要素(cl)のビット位置に1が重複していなければ同じ素数を使用していないことになるので論理積を取って0になれば正しい組み合わせとしてカウントします(76〜78行目)。

最後まで処理したらcntの値が組み合わせ数になります。

雑感

絶対にもっとうまい方法があるように思えます。直感的には動的計画法でできそうなのですが、重りを左側に載せる場合、右側に載せる場合、使わない場合があって、途中で必要のないパターンを捨てるか作らない方法がわからないため計算量がとんでもないことになりそうなので諦めました。逆に引き算でやるのは作ってみたのですが速度的に全然ダメでした。
話は変わりますが、この前段階でビット演算を使わず、配列の差集合で重複をチェックするプログラムを書いています。多分遅くてダメだろうということは予測していましたが、結果から何かヒントが得られればという感じでした。やってみたらローカルで6秒あまりだったのでひょっとしたらビット演算にすればいけるかも、と思ってこのプログラムにしたというわけです。
ちなみにビット演算にしたことでローカルで0.4秒くらいになったので15倍ほど早くなったことになります。