bzoj3160 萬徑人蹤滅

題意:給定只有ab的字符串,求其中不連續非空迴文子序列的個數。ide

解:用全部迴文子序列減去迴文子串。spa

容易想到枚舉中線。code

設f[i]表示以i爲中線,迴文字符的個數。那麼迴文子序列就是∑(2f[i] - 1)blog

怎麼求f[i]呢?卷積!字符串

咱們考慮把i變成中線 * 2,那麼f[i] = ( ∑(s[j] == s[i - j])  + 1) / 2get

令a = 1,b = 0,那麼卷積得出的就是全部a的迴文字符數。同理能夠求得全部b的迴文字符數。string

而後用迴文自動機求一遍迴文子串的數目,相減便可。it

  1 #include <cstdio>
  2 #include <cstring>
  3 #include <algorithm>
  4 #include <cmath>
  5 
  6 typedef long long LL;
  7 const int N = 100010;
  8 const LL MO = 1e9 + 7;
  9 const double pi = 3.1415926535897932384626;
 10 
 11 struct cp {
 12     double x, y;
 13     cp(double X = 0, double Y = 0) {
 14         x = X;
 15         y = Y;
 16     }
 17     inline cp operator +(const cp &w) const {
 18         return cp(x + w.x, y + w.y);
 19     }
 20     inline cp operator -(const cp &w) const {
 21         return cp(x - w.x, y - w.y);
 22     }
 23     inline cp operator *(const cp &w) const {
 24         return cp(x * w.x - y * w.y, x * w.y + y * w.x);
 25     }
 26 }a[N << 2];
 27 
 28 int r[N << 2], f[N << 1];
 29 char s[N];
 30 
 31 inline LL qpow(LL a, int b) {
 32     LL ans = 1;
 33     while(b) {
 34         if(b & 1) {
 35             ans = ans * a % MO;
 36         }
 37         a = a * a % MO;
 38         b = b >> 1;
 39     }
 40     return ans;
 41 }
 42 
 43 inline void FFT(int n, cp *a, int f) {
 44     for(int i = 0; i < n; i++) {
 45         if(i < r[i]) {
 46             std::swap(a[i], a[r[i]]);
 47         }
 48     }
 49 
 50     for(int len = 1; len < n; len <<= 1) {
 51         cp Wn(cos(pi / len), f * sin(pi / len));
 52         for(int i = 0; i < n; i += (len << 1)) {
 53             cp w(1, 0);
 54             for(int j = 0; j < len; j++) {
 55                 cp t = a[i + len + j] * w;
 56                 a[i + len + j] = a[i + j] - t;
 57                 a[i + j] = a[i + j] + t;
 58                 w = w * Wn;
 59             }
 60         }
 61     }
 62 
 63     if(f == -1) {
 64         for(int i = 0; i <= n; i++) {
 65             a[i].x /= n;
 66         }
 67     }
 68     return;
 69 }
 70 
 71 namespace pam {
 72     int tr[N][26], fail[N], cnt[N], len[N], last, tot;
 73     inline void init() {
 74         len[1] = -1;
 75         fail[0] = fail[1] = 1;
 76         tot = last = 1;
 77         return;
 78     }
 79     inline int getfail(int d, int x) {
 80         while(s[d - len[x] - 1] != s[d]) {
 81             x = fail[x];
 82         }
 83         return x;
 84     }
 85     inline void insert(int d) {
 86         int f = s[d] - 'a';
 87         int p = getfail(d, last);
 88         if(!tr[p][f]) {
 89             ++tot;
 90             len[tot] = len[p] + 2;
 91             fail[tot] = tr[getfail(d, fail[p])][f];
 92             tr[p][f] = tot;
 93         }
 94         last = tr[p][f];
 95         cnt[last]++;
 96         return;
 97     }
 98     inline LL count() {
 99         LL ans = 0;
100         for(int i = tot; i >= 2; i--) {
101             ans = (ans + cnt[i]) % MO;
102             (cnt[fail[i]] += cnt[i]) %= MO;
103         }
104         return ans;
105     }
106 }
107 
108 int main() {
109     scanf("%s", s);
110     int n = strlen(s) - 1;
111     for(int i = 0; i <= n; i++) {
112         a[i].x = (s[i] == 'a');
113     }
114     int len = 2, lm = 1;
115     while(len <= n + n) {
116         len <<= 1;
117         lm++;
118     }
119     for(int i = 1; i <= len; i++) {
120         r[i] = (r[i >> 1] >> 1) | ((i & 1) << (lm - 1));
121     }
122 
123     FFT(len, a, 1);
124     for(int i = 0; i <= len; i++) {
125         a[i] = a[i] * a[i];
126     }
127     FFT(len, a, -1);
128     for(int i = 0; i <= n + n; i++) {
129         f[i] = ((int)(a[i].x + 0.5) + 1) >> 1;
130     }
131 
132     for(int i = 0; i <= n; i++) {
133         a[i].x = (s[i] == 'b');
134         a[i].y = 0;
135     }
136     for(int i = n + 1; i <= len; i++) {
137         a[i] = cp(0, 0);
138     }
139     FFT(len, a, 1);
140     for(int i = 0; i <= len; i++) {
141         a[i] = a[i] * a[i];
142     }
143     FFT(len, a, -1);
144     for(int i = 0; i <= n + n; i++) {
145         f[i] += ((int)(a[i].x + 0.5) + 1) >> 1;
146     }
147 
148     LL ans = 0;
149     for(int i = 0; i <= n + n; i++) {
150         ans = (ans + qpow(2ll, f[i]) - 1) % MO;
151     }
152     pam::init();
153     for(int i = 0; i <= n; i++) {
154         pam::insert(i);
155     }
156     ans = (ans - pam::count() + MO) % MO;
157     printf("%lld", ans);
158     return 0;
159 }
AC代碼