(Analysis by Spencer Compton)

First, we make an observation about what the shortest distance to some $K$-tuple is. The distance to the $K$-tuple $(j_1, j_2, \ldots, j_K)$ is the smallest $d$ such that in each graph $G_i$ there is a walk of exactly $d$ steps that ends at $j_i$. To use this, we exploit more structure about what lengths of walks can end at a vertex. If there is a walk of length $x$ in $G_i$ that ends at $j_i$, then we know there are also walks of length $x, x+2, x+4, \ldots$, because one can repeatedly take one step away from $j_i$ and one step back towards $j_i$.

Moreover, if $\text{even}_i[v]$ denotes the length of the shortest path of even length in $G_i$ that ends at $v$, and $\text{odd}_i[v]$ the length of the shortest path of odd length, then the valid distances for $v$ are the union of $\{\text{even}_i[v], \text{even}_i[v]+2, \text{even}_i[v]+4, \ldots\}$ and $\{\text{odd}_i[v], \text{odd}_i[v]+2, \text{odd}_i[v]+4, \ldots\}$. (If there is no even path or no odd path, ignore the corresponding set.) We can compute the quantities $\text{even}_i[v]$ and $\text{odd}_i[v]$ by creating a doubled graph (one copy of every vertex representing odd parity and one representing even parity, with every edge going between the two copies) and running a BFS on it.

Now, we want to use this structure to compute the sum of all distances. Core to this question is knowing, for some $K$-tuple, what the minimum possible distance is. If we decide this distance will be even (and there exists an even path for each node in the tuple), then the distance is simply the maximum $\text{even}_i[v]$ in the tuple. More concretely, we denote the sum of the distances of these tuples as compute_sum($L_{\text{even}}$), where $L_{\text{even}}$ is a list of pairs containing each node's $\text{even}_i[v]$ and its corresponding graph $i$ (if the node has an even path). Then compute_sum calculates the sum, over all valid tuples, of the maximum value. (The analogous statement holds if we decide this distance will be odd.)

But deciding whether a given tuple should be even or odd is difficult. Consider, instead, that we immediately calculate compute_sum($L_{\text{even}}$) + compute_sum($L_{\text{odd}}$). If an entire tuple could use even paths or odd paths, then we overcounted the answer for that tuple by exactly the larger of the two quantities (e.g. if odd was the worse decision for that tuple, then by the maximum $\text{odd}_i[v]$ in that tuple). Finally, we can correct this by subtracting compute_sum($L_{\max}$), where $L_{\max}$ denotes a list with $\max(\text{even}_i[v], \text{odd}_i[v])$ for the nodes that have both an even path and an odd path.

All that remains is how to calculate compute_sum($L$) for some list $L$. One way is to compute $\text{less}[x]$, the number of tuples whose maximum is at most $x$; then the answer is $\sum_i (\text{less}[i] - \text{less}[i-1]) \times i$. To calculate $\text{less}[x]$, we note that it is equal to the product of $\text{cnt}_i[x]$ over all $i$, where $\text{cnt}_i[x]$ is the number of nodes from graph $i$ whose corresponding value is at most $x$. Directly computing this would be too slow, but we can optimize.

For each graph $G_i$, we compute $\text{cnt}_i[x]$ for all $x < 2N_i$. For all larger $x$, $\text{cnt}_i[x] = \text{cnt}_i[2N_i - 1]$. To account for this, we maintain a suffix product array (i.e. similar to a prefix sum array, but for suffixes and multiplication instead) suffix_prod and modify it such that all elements in the suffix starting at index $2N_i$ are multiplied by $\text{cnt}_i[2N_i - 1]$. For the first $2N_i$ values of $\text{cnt}_i$, we similarly keep an array prefix_prod and multiply prefix_prod[$x$] by $\text{cnt}_i[x]$. Then, we can finally compute $\text{less}[x]$ as prefix_prod[$x$] $\times$ suffix_prod[$x$]. This runs in linear time, so in total our algorithm runs in $O(\sum N_i + \sum M_i)$ time.

It is also possible to calculate compute_sum(L) using a segment tree or modular inverses.

Spencer's code:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

// Modulus applied to every product and to the final answer.
const ll mod = 1000000007LL;

// Number of graphs. Must be set before compute_sum is called.
int k;
// n[i] is the number of vertices of graph i (0-indexed). Room for 50000 graphs.
int n[50000];
// Sentinel distance meaning "not reachable".
const int inf = 100000000;

/**
 * Sums, over every tuple that picks exactly one listed entry from each of the
 * k graphs, the maximum value in the tuple. The result is reduced modulo `mod`.
 *
 * Reads the globals `k` and `n`.
 *
 * @param li pairs (value, graph index). Precondition: 0 <= graph index < k and
 *           0 <= value < 2*n[graph index]; this is not checked.
 * @return the sum in [0, mod). It is 0 when k == 0 or when some graph has no
 *         entry in `li`, because then no tuple exists.
 */
ll compute_sum(const vector<pair<int, int> >& li){
	// Bucket the values by graph. This is a heap-allocated vector on purpose:
	// an array sized by the runtime value k is not standard C++ and would put
	// k vector headers on the stack.
	vector<vector<int> > values(k);
	for(size_t e = 0; e<li.size(); e++){
		values[li[e].second].push_back(li[e].first);
	}

	int maxn = 0;
	for(int g = 0; g<k; g++){
		maxn = max(maxn, 2*n[g]);
	}

	// less[x] = number of tuples whose maximum is <= x = product over graphs of
	// cnt_g[x], where cnt_g[x] counts the entries of graph g with value <= x.
	//   prefix_prod[x] collects cnt_g[x] for the graphs with x <  2*n[g];
	//   suffix_prod[x] collects the constant cnt_g[2*n[g]-1] for graphs with
	//                  2*n[g] <= x (after the running product below).
	vector<ll> prefix_prod(maxn, 1);
	vector<ll> suffix_prod(maxn+1, 1);
	for(int g = 0; g<k; g++){
		const int width = 2*n[g];
		vector<ll> cnt(width, 0);
		for(size_t e = 0; e<values[g].size(); e++){
			cnt[values[g][e]]++;
		}
		for(int x = 0; x<width; x++){
			if(x>0){
				cnt[x] += cnt[x-1];
			}
			prefix_prod[x] = prefix_prod[x]*cnt[x]%mod;
		}
		// A graph without vertices offers no choice at all, so it contributes 0.
		const ll total = width>0 ? cnt[width-1] : 0;
		suffix_prod[width] = suffix_prod[width]*total%mod;
	}
	// Turn the point factors stored at index 2*n[g] into "applies to every
	// index >= 2*n[g]" by taking a running product.
	for(size_t x = 1; x<suffix_prod.size(); x++){
		suffix_prod[x] = suffix_prod[x]*suffix_prod[x-1]%mod;
	}

	// answer = sum over x of x * (less[x] - less[x-1]); x = 0 contributes nothing.
	ll ans = 0;
	for(int x = 1; x<maxn; x++){
		// Both products are below mod^2 < 2^63, so the difference fits in ll.
		const ll delta = (prefix_prod[x]*suffix_prod[x] - prefix_prod[x-1]*suffix_prod[x-1])%mod;
		ans = (ans + delta*(ll)x)%mod;
	}
	// The differences are taken between reduced values, so ans may be negative.
	if(ans<0){
		ans += mod;
	}
	return ans;
}
int main(){
	ios_base::sync_with_stdio(false);
	cin.tie(0);
	vector<pair<int, int> > evens;
	vector<pair<int, int> > odds;
	vector<pair<int, int> > both;
	cin >> k;
	for(int i = 0; i<k; i++){
		int m;
		cin >> n[i] >> m;
		vector<int> adj[2*n[i]];
		for(int j = 0; j<m; j++){
			int a, b;
			cin >> a >> b;
			a--;
			b--;
			adj[a].push_back(n[i]+b);
			adj[b].push_back(n[i]+a);
			adj[n[i]+a].push_back(b);
			adj[n[i]+b].push_back(a);
		}
		vector<int> dist(2*n[i], inf);
		vector<int> li;
		dist[0] = 0;
		li.push_back(0);
		for(int j = 0; j<li.size(); j++){
			int now = li[j];
			for(int a = 0; a<adj[now].size(); a++){
				int to = adj[now][a];
				if(dist[to]==inf){
					dist[to] = dist[now]+1;
					li.push_back(to);
				}
			}
		}
		for(int j = 0; j<n[i]; j++){
			if(dist[j]<inf){
				evens.push_back(make_pair(dist[j],i));
			}
			if(dist[j+n[i]]<inf){
				odds.push_back(make_pair(dist[j+n[i]],i));
			}
			if(max(dist[j],dist[j+n[i]])<inf){
				both.push_back(make_pair(max(dist[j],dist[j+n[i]]),i));
			}
		}
	}
	ll ans = compute_sum(evens)+compute_sum(odds)-compute_sum(both);
	ans %= mod;
	if(ans<0LL){
		ans += mod;
	}
	cout << ans << "\n";
}

Danny Mittal's code (with modular inverse):

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.*;
 
/**
 * Alternative solution that combines the per-graph counts using modular inverses.
 *
 * Reads K graphs from standard input, runs a BFS over a parity-doubled copy of each
 * graph, and prints one value modulo MOD.
 */
public class SingleSourceShortestPath {
    public static final long MOD = 1000000007L;
 
    /**
     * Raises base to the power MOD - 2 modulo MOD by square-and-multiply.
     * By Fermat's little theorem this is the multiplicative inverse of base modulo MOD,
     * provided MOD is prime and base is not a multiple of MOD.
     *
     * @param base value to invert; the callers in this class only pass non-zero counts.
     *             It is not reduced before the first squaring, so it must be small
     *             enough that base * base fits in a long.
     * @return base^(MOD - 2) mod MOD
     */
    public static long inverse(long base) {
        int exponent = (int) MOD - 2;
        long res = 1;
        while (exponent != 0) {
            if (exponent % 2 == 1) {
                res *= base;
                res %= MOD;
            }
            exponent /= 2;
            base *= base;
            base %= MOD;
        }
        return res;
    }
 
    /**
     * Reads the input, accumulates per-distance factors over all graphs and prints the answer.
     */
    public static void main(String[] args) throws IOException {
        BufferedReader in = new BufferedReader(new InputStreamReader(System.in));
        // amts[d]  : per-distance factor for "same parity" counts (even d and odd d share the array).
        // amts2[d] : per-distance factor for vertices reachable with both parities.
        // *Zero[d] : some graph has a count of 0 at d; applied only after the running products,
        //            because a zero factor cannot be divided out again.
        long[] amts = new long[200000];
        long[] amts2 = new long[200000];
        boolean[] amtsZero = new boolean[200000];
        boolean[] amts2Zero = new boolean[200000];
        Arrays.fill(amts, 1);
        Arrays.fill(amts2, 1);
        int k = Integer.parseInt(in.readLine());
        boolean anyBipartite = false;
        for (int g = 0; g < k; g++) {
            // Skips one line before each graph header (presumably a blank separator; confirm
            // against the input format).
            in.readLine();
            StringTokenizer tokenizer = new StringTokenizer(in.readLine());
            int n = Integer.parseInt(tokenizer.nextToken());
            int m = Integer.parseInt(tokenizer.nextToken());
            // Doubled graph, 1-indexed: index a is vertex a, index n + a is its second copy.
            List<Integer>[] adj = new List[(2 * n) + 1];
            for (int a = 1; a <= 2 * n; a++) {
                adj[a] = new ArrayList<>();
            }
            for (int j = 1; j <= m; j++) {
                tokenizer = new StringTokenizer(in.readLine());
                int a = Integer.parseInt(tokenizer.nextToken());
                int b = Integer.parseInt(tokenizer.nextToken());
                // Every edge links the two copies, so each step switches copy (parity).
                adj[a].add(n + b);
                adj[n + b].add(a);
                adj[n + a].add(b);
                adj[b].add(n + a);
            }
            // BFS from vertex 1 in the first copy; -1 marks "not reached".
            // dist[a] is then an even step count and dist[n + a] an odd one.
            int[] dist = new int[(2 * n) + 1];
            Arrays.fill(dist, -1);
            dist[1] = 0;
            LinkedList<Integer> q = new LinkedList<>();
            q.add(1);
            while (!q.isEmpty()) {
                int a = q.remove();
                for (int b : adj[a]) {
                    if (dist[b] == -1) {
                        dist[b] = dist[a] + 1;
                        q.add(b);
                    }
                }
            }
            // Vertex 1 cannot be reached again after an odd number of steps: the code treats
            // this graph as bipartite.
            if (dist[n + 1] == -1) {
                anyBipartite = true;
            }
            // freq[d]  : number of vertex copies at distance exactly d. Even distances come from
            //            the first copy and odd distances from the second, so one array serves both.
            // freq2[d] : number of vertices reached in both copies, bucketed by the larger distance.
            long[] freq = new long[2 * n];
            long[] freq2 = new long[2 * n];
            for (int a = 1; a <= n; a++) {
                if (dist[a] != -1) {
                    freq[dist[a]]++;
                }
                if (dist[n + a] != -1) {
                    freq[dist[n + a]]++;
                }
                if (dist[a] != -1 && dist[n + a] != -1) {
                    freq2[Math.max(dist[a], dist[n + a])]++;
                }
            }
            // Prefix sums: stride 2 for freq (each parity separately), stride 1 for freq2.
            // Starting at d = 2 skips freq2[1] += freq2[0]; freq2[0] is always 0 because the
            // second-copy distance is at least 1.
            for (int d = 2; d < 2 * n; d++) {
                freq[d] += freq[d - 2];
                freq2[d] += freq2[d - 1];
            }
            // Multiply in the ratio freq[d] / freq[d - 2] (resp. freq2[d] / freq2[d - 1]) so that
            // the running products taken after this loop telescope to this graph's count at d,
            // and keep this graph's last count for every d >= 2 * n.
            for (int d = 0; d < 2 * n; d++) {
                if (freq[d] == 0L) {
                    amtsZero[d] = true;
                } else {
                    amts[d] *= freq[d];
                    amts[d] %= MOD;
                    if (d >= 2 && freq[d - 2] != 0L) {
                        amts[d] *= inverse(freq[d - 2]);
                        amts[d] %= MOD;
                    }
                }
                if (freq2[d] == 0L) {
                    amts2Zero[d] = true;
                } else {
                    amts2[d] *= freq2[d];
                    amts2[d] %= MOD;
                    if (d >= 1 && freq2[d - 1] != 0L) {
                        amts2[d] *= inverse(freq2[d - 1]);
                        amts2[d] %= MOD;
                    }
                }
            }
        }
        // Running products: stride 2 for amts (per parity), stride 1 for amts2.
        // This must happen before the zero flags are applied below.
        for (int d = 2; d < 200000; d++) {
            amts[d] *= amts[d - 2];
            amts[d] %= MOD;
            amts2[d] *= amts2[d - 1];
            amts2[d] %= MOD;
        }
        // Apply the recorded zeros: no tuple exists at a distance where some graph has count 0.
        for (int d = 0; d < 200000; d++) {
            if (amtsZero[d]) {
                amts[d] = 0;
            }
            if (amts2Zero[d]) {
                amts2[d] = 0;
            }
        }
        // A graph flagged as bipartite has no vertex reached in both copies, but its zero flags
        // only cover d < 2 * n of that graph, so the whole array is cleared here.
        if (anyBipartite) {
            Arrays.fill(amts2, 0);
        }
        // answer = sum of d * (amts[d] - amts[d - 2]) - sum of d * (amts2[d] - amts2[d - 1]).
        // The "previous" terms are skipped for d < 2.
        long answer = 0;
        for (int d = 0; d < 200000; d++) {
            long dl = d;
            answer += dl * amts[d];
            answer -= dl * amts2[d];
            if (d >= 2) {
                answer -= dl * amts[d - 2];
                answer += dl * amts2[d - 1];
            }
            answer %= MOD;
        }
        // Java's % can return a negative value; shift it into [0, MOD).
        answer += MOD;
        answer %= MOD;
        System.out.println(answer);
    }
}