-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathloter_multiRegions_old.py
122 lines (93 loc) · 4.44 KB
/
loter_multiRegions_old.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# -*- coding: utf-8 -*-
"""
Created on Mon Apr 12 21:47:54 2021
@author: Yudongcai
@Email: [email protected]
"""
import gzip
import typer
import allel
import numpy as np
import pandas as pd
from collections import defaultdict
import loter.locanc.local_ancestry as lc
def vcf2npy(vcffile, samples, region=None, return_more=False):
    """Load genotype calls from a VCF as a haplotype-by-site matrix.

    Each sample contributes two consecutive rows: row ``2*i`` holds the first
    allele of every call and row ``2*i + 1`` the second allele, for the i-th
    sample returned by the reader.

    Parameters
    ----------
    vcffile : str
        Path of the VCF file, passed to ``allel.read_vcf``.
    samples : list
        Sample selection, passed to ``allel.read_vcf``.
    region : str, optional
        Region string, passed to ``allel.read_vcf``.
    return_more : bool, optional
        When true, also return the sample IDs, chromosomes and positions.

    Returns
    -------
    numpy.ndarray or tuple
        The ``uint8`` matrix of shape ``(2 * n_samples, n_sites)``; when
        ``return_more`` is true, the tuple
        ``(matrix, samples, chroms, positions)``.

    Raises
    ------
    ValueError
        If the reader yields no genotype records for the requested region.
    """
    fields = ['samples', 'calldata/GT']
    if return_more:
        fields += ['variants/CHROM', 'variants/POS']
    callset = allel.read_vcf(vcffile, samples=samples, region=region,
                             fields=fields)
    # Fail with a clear message instead of a cryptic subscript error when the
    # reader has nothing to return for this region.
    if callset is None or 'calldata/GT' not in callset:
        raise ValueError(
            f'no variants found in {vcffile!r} for region {region!r}')

    genotypes = callset['calldata/GT']
    n_sites, n_samples = genotypes.shape[:2]
    # Interleave the two alleles in the reader's own integer dtype; going
    # through a float64 buffer costs 8x the memory and makes the final cast of
    # negative values platform-dependent.
    mat_haplo = np.empty((2 * n_samples, n_sites), dtype=genotypes.dtype)
    mat_haplo[::2] = genotypes[:, :, 0].T
    mat_haplo[1::2] = genotypes[:, :, 1].T
    # NOTE(review): negative allele codes (presumably missing calls) wrap to
    # 255 here - confirm that this is what the downstream caller expects.
    mat_haplo = mat_haplo.astype(np.uint8)

    if return_more:
        return (mat_haplo, callset['samples'],
                callset['variants/CHROM'], callset['variants/POS'])
    return mat_haplo
def load_group(groupfile):
    """Read the sample-to-group assignment file.

    The file has two whitespace-separated columns: sample ID and group ID.
    Blank lines are ignored.

    Parameters
    ----------
    groupfile : str
        Path of the group file.

    Returns
    -------
    tuple
        ``(sample2group, group2samples)`` where ``sample2group`` maps each
        sample ID to its group ID and ``group2samples`` is a
        ``defaultdict(list)`` mapping each group ID to its sample IDs in file
        order.

    Raises
    ------
    ValueError
        If a non-blank line does not contain exactly two columns.
    """
    sample2group = {}
    group2samples = defaultdict(list)
    with open(groupfile) as f:
        for lineno, line in enumerate(f, start=1):
            fields = line.split()
            if not fields:
                # A trailing newline or spacer line is not an error.
                continue
            if len(fields) != 2:
                raise ValueError(
                    f'{groupfile}:{lineno}: expected 2 columns '
                    f'(sample, group), got {len(fields)}')
            sample, group = fields
            sample2group[sample] = group
            group2samples[group].append(sample)
    return sample2group, group2samples
def load_regions(regionfile):
    """Read a region file and return ``chrom:start-end`` strings.

    The file has at least three whitespace-separated columns (chrom, start,
    end); any further columns are ignored. Blank lines and lines starting
    with ``#`` are skipped.

    Parameters
    ----------
    regionfile : str
        Path of the region file.

    Returns
    -------
    list of str
        One ``'chrom:start-end'`` string per data line, in file order.

    Raises
    ------
    ValueError
        If a data line has fewer than three columns.
    """
    regions = []
    with open(regionfile) as f:
        for lineno, line in enumerate(f, start=1):
            fields = line.split()
            if not fields or fields[0].startswith('#'):
                # Skip spacer lines and comment/header lines.
                continue
            if len(fields) < 3:
                raise ValueError(
                    f'{regionfile}:{lineno}: expected at least 3 columns '
                    f'(chrom, start, end), got {len(fields)}')
            chrom, start, end = fields[:3]
            regions.append(f'{chrom}:{start}-{end}')
    return regions
def _split_ids(text):
    """Split a comma-separated option value into stripped, non-empty IDs."""
    return [item.strip() for item in text.split(',') if item.strip()]


def _haplotype_ids(samples):
    """Return the labels ``<sample>_1`` and ``<sample>_2`` for each sample, in order."""
    return [f'{sample}_{copy}' for sample in samples for copy in (1, 2)]


def _melt_calls(values, hap_ids, chroms, positions, value_name):
    """Reshape a (haplotype x site) matrix into long form, one row per site and haplotype."""
    wide = pd.DataFrame(values.T, columns=hap_ids)
    wide['chrom'] = chroms
    wide['pos'] = positions
    return wide.melt(id_vars=['chrom', 'pos'], value_vars=hap_ids,
                     var_name='hapID', value_name=value_name)


def main(vcffile: str = typer.Argument(..., help="总vcf文件"),
         outfile: str = typer.Argument(..., help='输出文件名(gzipped), XX.tsv.gz'),
         groupfile: str = typer.Option(..., help='样本分群信息,两列,一列vcf中的样本ID,一列对应的群体ID'),
         regionfile: str = typer.Option(..., help='区域文件,三列,chrom\\tstart\\tend'),
         refpops: str = typer.Option(..., help='参考群体ID,此ID须存在于groupfile的第二列中。多个群体用,分割'),
         querypops: str = typer.Option(..., help='待检测群体ID,此ID须存在于groupfile的第二列中。多个群体用,分割'),
         threads: int = typer.Option(1, help='使用的线程数'),
         nbags: int = typer.Option(20, help='number of resampling in the bagging')):
    """Run Loter local-ancestry inference region by region and write a gzipped TSV."""
    sample2group, group2samples = load_group(groupfile)

    # Reject unknown or empty population lists up front; otherwise a typo
    # silently turns into an empty sample selection.
    query_groups = _split_ids(querypops)
    ref_groups = _split_ids(refpops)
    if not query_groups or not ref_groups:
        raise typer.BadParameter('refpops and querypops must each name at least one group')
    unknown = [g for g in query_groups + ref_groups if g not in group2samples]
    if unknown:
        raise typer.BadParameter(
            f"group ID(s) not found in groupfile: {', '.join(unknown)}")

    querysamples = []
    for group in query_groups:
        querysamples.extend(group2samples[group])

    cols = ['chrom', 'pos', 'region', 'hapID', 'group', 'sourceIndex', 'nb_bagging']
    # Create/truncate the output and write the header once; every region then
    # appends its rows to the same gzip file.
    with gzip.open(outfile, 'wb') as f:
        header = '\t'.join(cols) + '\n'
        f.write(header.encode())

    # Process each region independently.
    for region in load_regions(regionfile):
        print(region)
        H_query, query_samples, chrs, sites = vcf2npy(
            vcffile, querysamples, region=region, return_more=True)
        H_refs = [vcf2npy(vcffile, group2samples[group], region=region)
                  for group in ref_groups]
        print(f'{len(sites)} sites loaded.')

        res_loter = lc.loter_local_ancestry(
            H_refs, H_query, num_threads=threads, nb_bagging=nbags)

        # Label haplotypes from the sample list returned with the matrix, not
        # from the group-file order: the labels must follow the row order of
        # H_query, and only samples actually loaded get a label.
        hapIDs = _haplotype_ids(query_samples)
        hapID2group = {hap: sample2group[sample]
                       for sample in query_samples
                       for hap in _haplotype_ids([sample])}

        mdf = _melt_calls(res_loter[0], hapIDs, chrs, sites, 'sourceIndex')
        mbdf = _melt_calls(res_loter[1], hapIDs, chrs, sites, 'nb_bagging')
        # Both long tables are built identically, so their rows line up one to
        # one; positional assignment avoids a key-based merge, which would
        # multiply rows whenever two records share the same chrom/pos.
        mdf['nb_bagging'] = mbdf['nb_bagging'].values
        mdf['group'] = mdf['hapID'].map(hapID2group)
        # Keep the region label: regions are computed separately and are not
        # contiguous, so later merging of sites into intervals must stay
        # within a single region.
        mdf['region'] = region
        mdf[cols].to_csv(outfile, sep='\t', index=False, header=False,
                         mode='a', compression='gzip')
# Script entry point: hand main to Typer, which runs it as the command-line command.
if __name__ == "__main__":
    typer.run(main)
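# TODO: add an option to de-duplicate (remove redundancy among) regions
# TODO: add an option to drop sites without variation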