輪郭をなぞるだけのブログ

浅学菲才のためにおそらく嘘も多い

ARC036 C - 偶然ジェネレータ

atcoder.jp

問題概要

'0', '1', '?' からなる長さ N の文字列が与えられる。
'?' には、'0' または '1' が入れられる。
それを埋めた '0' と '1' からなる文字列の任意の連続部分列の中で '0' と '1' の個数の差が K 以下となる文字列の個数を 1000000007 で割った余りを求める。

制約

1 \leq N \leq 300
1 \leq K \leq N

(考察)

制約も小さいし、雰囲気は DP かと思ったけど、定義できないね。
最近見てる問題難しくて困った。

dp[ i ][ j ][ k ] := i 桁目から左方向に見たときの ('0' の個数 - '1' の個数) の最大値が j, 最小値が -k であるときの通り数
と定義する。
j が負になるときは、代わりに 0 に遷移させる。
k の場合も同様である。
この定義、正直あまり理解しきれていない。

自分の中では、現在地からどれだけ上または下に戻れるかという状態を持っていると解釈している。
例えば、文字列が "000110111" であったとき、お絵かきをしてみると、下図のようになる。
f:id:OutLine:20200430214558p:plain

'1' のとき上方向、'0' のとき下方向に進んでいて、jk の遷移も書いてある。
(このお絵かきするなら、jk の定義は逆の方がよかったな…)
例えば、最終地点を見ると (j, k) = (0, 4) となっているが、これはその地点から左方向を見たときに、上方向には 0、下方向には 4 だけ進めるということを示している。

図を見ると分かるように、最終地点での j + k が連続部分列で取り得る最大の差となるので、それが K 以下となるものの合計を出力すればよい。

いやちょっとみんな纏め方うますぎるな、困る。

コード

modint の定義は長くなるので省略。
遷移では、K を超えてもそのまま計算させている。
超えるときは遷移させないのが普通な気がするが、j + k は短くはならなさそうなので、これでも多分大丈夫である。

#include <iostream>
#include <vector>
#include <string>
using namespace std;

constexpr int mod = 1000000007;
struct modint{ /* modint の定義 */ };

// dp[i][j][k] := i 桁目から左方向を見たとき、('0'の個数 - '1'の個数)
//                の最大値が j, 最小値が -k である通り数
// y 軸で捉えると、i 桁目の現在の座標 y_{i} から、正方向に j, 負方向に k だけ進めるということを表している。 
// j + k が連続部分列最大の差になるので、j + k <= K を満たすものの合計が答えとなる
modint dp[305][305][305];

int main(){
    int n, K;
    string s;
    cin >> n >> K >> s;

    dp[0][0][0] = 1;
    for(int i = 0; i < n; ++i){
        for(int j = 0; j <= i; ++j){
            for(int k = 0; j + k <= i; ++k){
                if(dp[i][j][k] == 0)    continue;
                
                if(s[i] == '0'){
                    dp[i + 1][j + 1][max(0, k - 1)] += dp[i][j][k];
                }
                else if(s[i] == '1'){
                    dp[i + 1][max(0, j - 1)][k + 1] += dp[i][j][k];
                }
                else{   // s[i] == '?'
                    dp[i + 1][j + 1][max(0, k - 1)] += dp[i][j][k];
                    dp[i + 1][max(0, j - 1)][k + 1] += dp[i][j][k];
                }
            }
        }
    }

    modint ans = 0;
    for(int i = 0; i <= K; ++i){
        for(int j = 0; i + j <= K; ++j){
            ans += dp[n][i][j];
        }
    }
    cout << ans << endl;
}
雑記

5 月、どうして...