Description
给你一个 N*N 的矩阵, 不用算矩阵乘法, 但是每次询问一个子矩形的第 K 小数.
Input
第一行两个数 N,Q, 表示矩阵大小和询问组数;
接下来 N 行 N 列一共 N*N 个数, 表示这个矩阵;
再接下来 Q 行每行 5 个数描述一个询问: x1,y1,x2,y2,k 表示找到以 (x1,y1) 为左上角, 以 (x2,y2) 为右下角的子矩形中的第 K 小数.
Output
对于每组询问输出第 K 小的数.
- Sample Input
- 2 2
- 2 1
- 3 4
- 1 2 1 2 1
- 1 1 2 2 3
- Sample Output
- 1
- 3
- HINT
矩阵中数字是 109 以内的非负整数;
20% 的数据: N<=100,Q<=1000;
40% 的数据: N<=300,Q<=10000;
60% 的数据: N<=400,Q<=30000;
100% 的数据: N<=500,Q<=60000.
可以离线, 把权值排序.
solve(b,e,l,r) 表示 b 到 e 的询问的答案在 l 到 r 范围的权值内.
答案 mid, 就插入前 mid 个数, 然后查询区间有多少个数, 可以用二维树状数组维护.
答案在左边就往左走, 在右边就减去左边的贡献往右走.
代码:
- #include <stdio.h>
- #include <string.h>
- #include <algorithm>
- using namespace std;
- #define N 60050
- int n,c[550][550],L[N],R[N],mid[N],ans[N];
- struct A {
- int v,x,y;
- }a[550*550];
- struct Q {
- int d,f,g,h,k,id;
- }q[N],t[N];
- bool cmp(const A &a,const A &b) {
- return a.v<b.v;
- }
- void fix(int x,int y,int v) {
- int i,j;
- for(i=x;i<=n;i+=i&(-i)) {
- for(j=y;j<=n;j+=j&(-j)) {
- c[i][j]+=v;
- }
- }
- }
- int inq(int x,int y) {
- int i,j,re=0;
- for(i=x;i;i-=i&(-i)) {
- for(j=y;j;j-=j&(-j)) {
- re+=c[i][j];
- }
- }
- return re;
- }
- int query(int x,int y,int x2,int y2) {
- x--; y--;
- return inq(x2,y2)-inq(x,y2)-inq(x2,y)+inq(x,y);
- }
- void solve(int b,int e,int l,int r) {
- int i;
- if(b>e) return ;
- if(l==r) {
- for(i=b;i<=e;i++) {
- ans[q[i].id]=a[l].v;
- }
- return ;
- }
- int mid=(l+r)>>1,lpos=b,rpos=e;
- for(i=l;i<=mid;i++) {
- fix(a[i].x,a[i].y,1);
- }
- for(i=b;i<=e;i++) {
- int sizls=query(q[i].d,q[i].f,q[i].g,q[i].h);
- if(sizls>=q[i].k) t[lpos++]=q[i];
- else q[i].k-=sizls,t[rpos--]=q[i];
- }
- for(i=b;i<=e;i++) q[i]=t[i];
- for(i=l;i<=mid;i++) {
- fix(a[i].x,a[i].y,-1);
- }
- solve(b,lpos-1,l,mid);
- solve(rpos+1,e,mid+1,r);
- }
- int main() {
- int m;
- scanf("%d%d",&n,&m);
- int i,tot=0,j;
- for(i=1;i<=n;i++) {
- for(j=1;j<=n;j++) {
- scanf("%d",&a[++tot].v);
- a[tot].x=i; a[tot].y=j;
- }
- }
- sort(a+1,a+n*n+1,cmp);
- for(i=1;i<=m;i++) {
- scanf("%d%d%d%d%d",&q[i].d,&q[i].f,&q[i].g,&q[i].h,&q[i].k);
- q[i].id=i;
- }
- solve(1,m,1,n*n);
- for(i=1;i<=m;i++) printf("%d\n",ans[i]);
- }
来源: http://www.bubuko.com/infodetail-2579034.html