CCPC2020 威海C 樹形dp

2020-10-26 12:01:30

C Rencontre

在這裡插入圖片描述

#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef long double LD;
typedef pair<int,LD> pill;
#define rep(i,a,b) for(int i=a;i<=b;i++)
#define per(i,a,b) for(int i=a;i>=b;i--)
#define fi first
#define se second
#define pb push_back
const int maxn=2e5+9;

vector<pill>v[maxn];
int sta[maxn];

LD all[4],sonl[maxn][4],fal[maxn][4],num[maxn][4];

void init(int p,int fa){
    rep(i,0,2){
        sonl[p][i]=fal[p][i]=num[p][i]=0;
        if(sta[p]&(1<<i))num[p][i]++;
    }
    for(int _=0;_<v[p].size();_++){
        int u=v[p][_].fi;
        if(u==fa)continue;
        init(u,p);
        LD len=v[p][_].se;
        rep(i,0,2){
            sonl[p][i]+=sonl[u][i];
            sonl[p][i]+=num[u][i]*len;
            num[p][i]+=num[u][i];
        }
    }
}
void init2(int p,int fa){
    for(int _=0;_<v[p].size();_++){
        int u=v[p][_].fi;
        if(u==fa)continue;
        LD len=v[p][_].se;
        rep(i,0,2){
            fal[u][i]=sonl[p][i]-sonl[u][i]-num[u][i]*len;
            fal[u][i]+=(num[p][i]-num[u][i])*len;
            fal[u][i]+=fal[p][i]+(all[i]-num[p][i])*len;
        }
        init2(u,p);
    }
}

LD ans[maxn][4];

void dfs(int p,int fa){
    LD tmpsan[4][4];
    rep(i,0,2){
        rep(j,0,2){
            tmpsan[i][j]=(all[i]-num[p][i])*(all[j]-num[p][j]);
        }
    }
    for(int _=0;_<v[p].size();_++){
        int u=v[p][_].fi;
        if(u==fa)continue;
        rep(i,0,2){
            rep(j,0,2){
                tmpsan[i][j]+=num[u][i]*num[u][j];
            }
        }
    }
    for(int _=0;_<v[p].size();_++){
        int u=v[p][_].fi;
        if(u==fa)continue;
        LD len=v[p][_].se;
        rep(i,0,2){
            int j=0;
            while(j==i)j++;
            int k=1;
            while(k==i||k==j)k++;

            LD w=sonl[u][i]+num[u][i]*len;
            LD q=(all[j]-num[u][j])*(all[k]-num[u][k])-(tmpsan[j][k]-num[u][j]*num[u][k]);
            ans[p][i]+=w*q;
        }
    }

    rep(i,0,2){
        int j=0;
        while(j==i)j++;
        int k=1;
        while(k==i||k==j)k++;

        LD w=fal[p][i];
        LD q=(num[p][j]*num[p][k])-(tmpsan[j][k]-(all[j]-num[p][j])*(all[k]-num[p][k]));
        ans[p][i]+=w*q;
    }

    for(int _=0;_<v[p].size();_++){
        int u=v[p][_].fi;
        if(u==fa)continue;

        dfs(u,p);
    }
}

int main()
{
    ios::sync_with_stdio(false);
    int n;
    cin>>n;
    rep(i,1,n-1){
        int a,b,c;
        cin>>a>>b>>c;
        v[a].pb({b,(LD)c});
        v[b].pb({a,(LD)c});
    }
    rep(i,0,2){
        int tmp;
        cin>>tmp;
        all[i]=(LD)tmp;
        rep(j,1,tmp){
            int p;
            cin>>p;
            sta[p]|=1<<i;
        }
    }
    init(1,0);
    init2(1,0);
    dfs(1,0);

//    rep(i,1,n){
//        rep(j,0,2){
//            printf("%d,%d:  fal:%lld sonl:%lld num:%lld ans:%lld\n",i,j,fal[i][j],sonl[i][j],num[i][j],ans[i][j]);
//        }
//    }

    LD A=0;
    LD base=(LD)(all[0])*(LD)(all[1])*(LD)(all[2]);
    rep(i,1,n){
        rep(j,0,2)
        A+=(LD)(ans[i][j])/base;
    }
    printf("%.10f\n",(double)A);
    return 0;
}