枚举起点做 spfa, 然后一条边在最短路上的条件是 dis[e[i].to]==dis[u]+e[i].va, 所以每次 spfa 完之后, dfs 出 a[i]表示经过 i 点的最短路的起点数, b[i]表示经过 i 点的最短路的终点数, 一条边 (u,v) 在当前起点下的答案就是 a[u]*b[v], 最终答案是总和
因为最短路构成一个 DAG, 所以 a 是按照类似拓扑序的东西来 dfs 的
- #include<iostream>
- #include<cstdio>
- #include<queue>
- #include<cstring>
- using namespace std;
- const int N=10005,mod=1e9+7;
- int n,m,h[N],cnt,dis[N],f[N],a[N],b[N],pr[N];
- bool v[N];
- struct qwe
- {
- int ne,no,to,va;
- }e[N];
- int read()
- {
- int r=0,f=1;
- char p=getchar();
- while(p>'9'||p<'0')
- {
- if(p=='-')
- f=-1;
- p=getchar();
- }
- while(p>='0'&&p<='9')
- {
- r=r*10+p-48;
- p=getchar();
- }
- return r*f;
- }
- void add(int u,int v,int w)
- {
- cnt++;
- e[cnt].ne=h[u];
- e[cnt].no=u;
- e[cnt].to=v;
- e[cnt].va=w;
- h[u]=cnt;
- }
- void spfa(int s)
- {
- deque<int>q;
- for(int i=1;i<=n;i++)
- dis[i]=1e9;
- memset(v,0,sizeof(v));
- q.push_back(s);
- dis[s]=0;
- v[s]=1;
- while(!q.empty())
- {
- int u=q.front();
- q.pop_front();
- v[u]=0;
- for(int i=h[u];i;i=e[i].ne)
- if(dis[e[i].to]>dis[u]+e[i].va)
- {
- dis[e[i].to]=dis[u]+e[i].va;
- if(!v[e[i].to])
- {
- v[e[i].to]=1;
- if(q.empty()||dis[q.front()]<dis[e[i].to])
- q.push_back(e[i].to);
- else
- q.push_front(e[i].to);
- }
- }
- }
- }
- void dfs(int u)
- {
- v[u]=1;
- for(int i=h[u];i;i=e[i].ne)
- if(dis[e[i].to]==dis[u]+e[i].va)
- {
- pr[e[i].to]++;
- if(!v[e[i].to])
- dfs(e[i].to);
- }
- }
- void dfsa(int u)
- {
- for(int i=h[u];i;i=e[i].ne)
- if(dis[e[i].to]==dis[u]+e[i].va)
- {
- v[i]=1;
- a[e[i].to]=(a[e[i].to]+a[u])%mod;
- if(!(--pr[e[i].to]))
- dfsa(e[i].to);
- }
- }
- void dfsb(int u)
- {
- b[u]=1;
- for(int i=h[u];i;i=e[i].ne)
- if(dis[e[i].to]==dis[u]+e[i].va)
- {
- if(!b[e[i].to])
- dfsb(e[i].to);
- b[u]=(b[u]+b[e[i].to])%mod;
- }
- }
- int main()
- {
- n=read(),m=read();
- for(int i=1;i<=m;i++)
- {
- int x=read(),y=read(),z=read();
- add(x,y,z);
- }
- for(int i=1;i<=n;i++)
- {
- spfa(i);
- memset(v,0,sizeof(v));
- dfs(i);
- memset(v,0,sizeof(v));
- memset(a,0,sizeof(a));
- memset(b,0,sizeof(b));
- a[i]=1;
- dfsa(i);
- dfsb(i);
- for(int j=1;j<=cnt;j++)
- if(v[j])
- f[j]=(f[j]+1ll*a[e[j].no]*b[e[j].to]%mod)%mod;
- }
- for(int i=1;i<=cnt;i++)
- printf("%d\n",f[i]);
- return 0;
- }
来源: http://www.bubuko.com/infodetail-2760056.html