백준: Class 6 - 1761(LCA 알고리즘)

2 분 소요


</br> LCA 알고리즘 정리하기

lca 알고리즘은 트리에서 두 노드의 최소 공통 조상을 구하는 알고리즘이다.
</br>

LCA

#include <bits/stdc++.h>
#define P pair<int, int>

using namespace std;

vector<P> tree[40001];
int depth[40001];
int ac[40001][16];
int max_level;

void getTree(int now, int pre){
    depth[now] = depth[pre] + 1;
    ac[now][0] = pre;

    for(int i = 1; i <= max_level; i++){
        int tmp = ac[now][i - 1];
        ac[now][i] = ac[tmp][i - 1];
    }

    for(int i = 0; i < tree[now].size(); i++){
        int nx = tree[now][i].second;
        if( nx != pre ) getTree(nx, now);
    }
}

int main(){
    ios::sync_with_stdio(false);
    cin.tie(NULL);

    int n;
    cin >> n;

    max_level = (int)floor(log2(40000));

    for(int i = 0; i < n-1; i++){
        int a, b, c;
        cin >> a >> b >> c;
        tree[a].push_back(make_pair(c, b));
        tree[b].push_back(make_pair(c, a));
    }

    depth[0] = -1;

    getTree(1, 0);

    int m;
    cin >> m;

    while( m-- ){
        int a, b;
        cin >> a >> b;
        if( depth[a] != depth[b] ){
            if( depth[a] > depth[b] ) swap(a, b);
            for(int i = max_level; i >= 0; i--){
                if( depth[a] <= depth[ac[b][i]] ) b = ac[b][i];
            }
        }
        int lca = a;
        if( a != b ){
            for(int i = max_level; i >= 0; i--){
                if( ac[a][i] != ac[b][i] ){
                    a = ac[a][i];
                    b = ac[b][i];
                }
                lca = ac[a][i];
            }
        }
        cout << lca << endl;
    }

}

lca 알고리즘의 기본적인 구조다.
</br>

1761: 정점들의 거리

https://www.acmicpc.net/problem/1761

#include <bits/stdc++.h>
#define P pair<int, int>

using namespace std;

vector<P> tree[40001];
int depth[40001];
int ac[40001][16];
int dist[40001];
int max_level;

void getTree(int now, int pre, int d){
    depth[now] = depth[pre] + 1;
    ac[now][0] = pre;
    dist[now] = d;

    for(int i = 1; i <= max_level; i++){
        int tmp = ac[now][i - 1];
        ac[now][i] = ac[tmp][i - 1];
    }

    for(int i = 0; i < tree[now].size(); i++){
        int nd = tree[now][i].first;
        int nx = tree[now][i].second;
        if( nx != pre ) getTree(nx, now, d + nd);
    }
}

int getlca(int a, int b){
    if( depth[a] != depth[b] ){
        if( depth[a] > depth[b] ) swap(a, b);
        for(int i = max_level; i >= 0; i--){
            if( depth[a] <= depth[ac[b][i]] ){
                b = ac[b][i];
            }
        }
    }
    int ret = a;
    if( a != b ){
        for(int i = max_level; i >= 0; i--){
            if( ac[a][i] != ac[b][i] ){
                a = ac[a][i];
                b = ac[b][i];
            }
            ret = ac[a][i];
        }
    }
    return ret;
}

int main(){
    ios::sync_with_stdio(false);
    cin.tie(NULL);

    int n;
    cin >> n;

    max_level = (int)floor(log2(40000));

    for(int i = 0; i < n-1; i++){
        int a, b, c;
        cin >> a >> b >> c;
        tree[a].push_back(make_pair(c, b));
        tree[b].push_back(make_pair(c, a));
    }

    depth[0] = -1;

    getTree(1, 0, 0);

    int m;
    cin >> m;

    while( m-- ){
        int a, b;
        cin >> a >> b;
        int lca = getlca(a, b);
        cout << dist[a] + dist[b] - 2*dist[lca] << '\n';
    }

}

두 정점 사이 거리를 구해야 한다.
1번을 루트로 트리를 만들면, 1번과의 거리를 dist[] 배열에 저장한다.
(1번부터 a까지의 거리) + (1번부터 b까지의 거리) - 2 * (1번부터 lca까지의 거리)로 두 정점 사이의 거리를 구할 수 있다.
</br>


이제 방학 동안 쌓아 놨던 여분 포스트가 없다
</br>

댓글남기기