CodeIQ:1の並びで小さい方から

私自身が表題の問題を解いた時のプログラムについて解説します。
問題の詳細は「1の並びで小さい方から」(CodeIQ)を参照してください。

問題の概要

問題を引用します。
数を2進数で表します。
このとき、1 の連なりの最大の長さをF(n)と書きます。
例えば、
 21(10進数) = 10101(2進数) なので、F(21)=1
 45(10進数) = 101101(2進数) なので、F(45)=2
 23(10進数) = 10111(2進数) なので、F(23)=3
 504(10進数) = 111111000(2進数) なので、F(504)=6
です。

さて、X と Y という2つの数を与えます。
1以上の整数 n について、F(n)=X である n のうち、小さい方から Y 番目の数を求めてください。

私のプログラム

Pythonで解答しています。

#!/usr/local/bin/python3

import fileinput
from functools import reduce

'''
	結果を求める。
	ロジックは次の通り(X=2の場合)

	1)	0b11を作る
	2)	結果集合を作る((1)の値を初期値とする)
	3)	結果集合から値を1つ取り出す
		値の左右に0か1を付け加えた新たな値を作成する
		(ex: 0b11 -> 0b011, 0b111, 0b110, 0b111)
		1がX個以上連続するものを排除する。
			作成した値を(1)より1ビット長いマスク(0b111)でマスクしてマスクと同じ値を除く
			値作成時に左側に値を追加した場合は上位ビットを、右側に値を追加した場合は下位ビット
			をマスクする。(ex: 2周目の場合0b1110と0b0111をマスクに使う)
		排除されなかった値を結果集合に加える
	4)	(3)を結果集合の要素数がYに達するまで繰り返す

	@param X 連続するビット長
	@param Y 何番目の値か(1始まり)
	@return 結果集合
'''
def solve(X,Y):
	# 一番小さな値を作る(X=3なら0b111)
	base = reduce(lambda a,b: a | (1<<b), range(0,X), 1)

	s = set([base])	# 結果集合に初期値を設定
	l = 0		# whileループカウンタ
	while len(s) < Y:
		RMASK = (base << 1) | 1
		LMASK = RMASK << l
		ns = set()

		for v in s:
			r0v = (v<<1)
			if (r0v & RMASK) != RMASK:
				ns.add(r0v)

			r1v = (v<<1) | 1
			if (r1v & RMASK) != RMASK:
				ns.add(r1v)

			l0v = v
			if (l0v & LMASK) != LMASK:
				ns.add(l0v)

			l1v = v | (1<<(X+l))
			if (l1v & LMASK) != LMASK:
				ns.add(l1v)

		s = ns
		l += 1

	return s

if __name__ == "__main__":
	for line in fileinput.input():
		if not line.strip():
			continue

		X,Y = map(int, line.strip().split(","))

	l = list(solve(X,Y))
	l = sorted(l)
	print(l[Y-1])

解説

コメントにやたらと詳しい説明があるのであまり解説するところはありませんが解説します。

考え方

このプログラムは単純にビットで問題の通りに値を作り、その中でY番目の値を返すという単純な作りです。ポイントはいかに効率よく値を生成するかにあります。

値の生成

solve()が値を生成する関数です。
基本的な考え方は動的計画法を用いて値を効率よく生成します。
コメントのようにX=2の場合を例にします。

まず、条件を満たす最小の値を生成します。X=2なら0b11が最小です。
28行目の処理がこの処理で、X回左シフトしながら1を足し合わせているだけです。
そして、これを元の集合としてSetの変数sに設定します。

次に値を更新しながらsの要素数がYを超えるまで処理します。sの要素数がYを超えたらY番目の値は必ず集合に含まれるのでループを終えます。

37行目〜52行目のループが値を生成する部分です。
r0vは「元の値の右側に0を付加した値」、r1vは「元の値の右側に1を付加した値」、l0vは「元の値の左側に0を付加した値」、l1vは「元の値の左側に1を付加した値」を保持する変数です。
元が0b11ならr0v=0b110、r1v=0b111、l0v=0b011=0b11、l1v=0b111になります。

そして、それぞれの値に初期値より1ビット分1が多い値をマスクとして使い、生成した値をマスクしたものがマスクと同じかどうかで1の連続を検出します。実際は値を右に伸ばす場合と左に伸ばす場合があるので右に伸ばした時用にRMASK、左に伸ばした時用にLMASKを用意します。LMASKはループ回数(=伸ばした数だけ)左シフトします。
X=2の場合、ループの1回目はRMASK=0b111、LMASK=0b111、2回目はRMASK=0b111、LMASK=0b1110になります。
例えば0b1101をLMASKでマスクすると0b1100になります。これをLMASKと比較して同じにならない場合、そのループで左側に伸ばした場所の1の連続回数はX以下であることが保証されます。そして、ビット数が最小のものから順に1ビットずつ増やしながら値を作っているので右端か左端だけを毎回チェックすれば全体がチェックされたことになります。
このようにしてチェックし、1がXより多く連続していない値だけを集めます。
ちなみに、チェックの必要がないパターン(46行目〜48行目、など)もチェックしていますが、これはコード上の対称性を保ちたかったためです(処理を省くと後で見直した時に何でこの部分は処理していないのかを思い出さなければいけないため)。

ここでポイントなのが値の集合にSetを使っていることで、重複した値が自動的に無視されるます。

最終的にY番目の値を得る

最終的にY番目の値を印字するため、solve()の戻り値をListにし、ソートしてY-1番目の要素を取り出しています。リストは0始まりなので入力値より1小さい要素が答えになります。

雑感

なかなか面白い問題でした。
割とすんなり解けましたが、1の連続を検出する方法として値を伸ばす方向の先端部分X+1ビットだけをチェックすれば良いことに気づいたのが大きいです。こういうのを短時間で思いつけた時はかなり気分が良いものです。