-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathplot_jointScatter.py
132 lines (103 loc) · 4.64 KB
/
plot_jointScatter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# -*- coding: utf-8 -*-
"""
Created on Sun Sep 2 15:28:26 2018
@author: YudongCai
@Email: [email protected]
last change: Fri Jul 30 11:23:36 2021
"""
import click
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.pyplot as plt
def load_infile(infile, x_col, y_col):
    """
    Load the two value columns used for plotting from a tab-separated file.

    Only the columns named ``x_col`` and ``y_col`` are read, both parsed as
    floats. Rows with a missing value in either column are discarded. The
    input may hold records from several chromosomes.

    Parameters
    ----------
    infile : path or file-like object accepted by ``pandas.read_csv``
    x_col : name of the column holding the x-axis values
    y_col : name of the column holding the y-axis values

    Returns
    -------
    pandas.DataFrame containing the two requested columns without NaN rows.
    """
    wanted = [x_col, y_col]
    column_types = dict.fromkeys(wanted, float)
    table = pd.read_csv(infile, sep='\t', usecols=wanted, dtype=column_types)
    return table.dropna()
def _freeze_limits(ax):
    """Pin the current x/y limits of *ax* and return them as ``(xlim, ylim)``."""
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    # Re-applying the limits explicitly fixes the view, so the cutoff lines
    # drawn afterwards cannot stretch the axes.
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    return xlim, ylim


def _passing_cutoffs(df, x_col, y_col, cutoff_x, cutoff_y):
    """Return the rows of *df* whose x and y values both reach their cutoffs."""
    keep = (df[x_col] >= cutoff_x) & (df[y_col] >= cutoff_y)
    return df.loc[keep, :]


def plot(df, x_col, y_col, cutoff_x, cutoff_y, highlightdata, markersize, outfile):
    """
    Draw a joint scatter plot with marginal histograms and save it to a file.

    The main panel is a scatter of ``df[x_col]`` against ``df[y_col]``; a
    histogram of the x values sits on top and a horizontal histogram of the
    y values sits on the right. Dashed grey lines mark ``cutoff_x`` and
    ``cutoff_y``. Points with x >= cutoff_x AND y >= cutoff_y are redrawn in
    red. If ``highlightdata`` is given, points from that file that also pass
    both cutoffs are drawn in blue on top.

    Parameters
    ----------
    df : pandas.DataFrame with at least the columns ``x_col`` and ``y_col``
    x_col, y_col : column names for the x and y values
    cutoff_x, cutoff_y : thresholds for the dashed lines and highlighting
    highlightdata : optional path to a tab-separated file in the same format
        as the main input (must contain ``x_col`` and ``y_col``); falsy to skip
    markersize : marker size passed to ``scatter`` as ``s``
    outfile : output path handed to ``savefig``

    The figure is always closed before returning, even if drawing or saving
    fails, so repeated calls do not accumulate open figures.
    """
    fig, axScatter = plt.subplots(figsize=(5.5, 5.5))
    try:
        # scatter plot of every point
        axScatter.scatter(df[x_col], df[y_col],
                          s=markersize,
                          color='k')

        # create new axes on the right and on the top of the current axes;
        # the second argument of append_axes is the height (width) of the
        # new axes in inches.
        divider = make_axes_locatable(axScatter)
        axHistx = divider.append_axes("top", 0.8, pad=0.1, sharex=axScatter)
        axHisty = divider.append_axes("right", 0.8, pad=0.1, sharey=axScatter)

        # hide the tick labels on the axes shared with the scatter panel
        plt.setp(axHistx.get_xticklabels() + axHisty.get_yticklabels(),
                 visible=False)

        # hide the right and top spines on all three panels
        for ax in (axHistx, axHisty, axScatter):
            ax.spines['right'].set_visible(False)
            ax.spines['top'].set_visible(False)

        # marginal histograms
        axHistx.hist(df[x_col], bins=100, color='#CCC0DA')
        axHisty.hist(df[y_col], bins=100, orientation='horizontal', color='#D8E4BC')

        # fix xlim/ylim of every panel before drawing the cutoff lines
        scatter_xlim, scatter_ylim = _freeze_limits(axScatter)
        _, histx_ylim = _freeze_limits(axHistx)
        histy_xlim, _ = _freeze_limits(axHisty)

        # cutoff lines spanning the full extent of each panel
        cutcolor = '#808080'
        cutstyle = 'dashed'
        axScatter.vlines(cutoff_x, scatter_ylim[0], scatter_ylim[1], linestyle=cutstyle, color=cutcolor)
        axHistx.vlines(cutoff_x, histx_ylim[0], histx_ylim[1], linestyle=cutstyle, color=cutcolor)
        axScatter.hlines(cutoff_y, scatter_xlim[0], scatter_xlim[1], linestyle=cutstyle, color=cutcolor)
        axHisty.hlines(cutoff_y, histy_xlim[0], histy_xlim[1], linestyle=cutstyle, color=cutcolor)

        # highlight the points that pass both cutoffs
        passed = _passing_cutoffs(df, x_col, y_col, cutoff_x, cutoff_y)
        axScatter.scatter(passed[x_col], passed[y_col],
                          s=markersize,
                          color='#E74C3C')

        if highlightdata:
            # extra sites to highlight, read the same way as the main input
            extra = pd.read_csv(highlightdata,
                                sep='\t',
                                usecols=[x_col, y_col],
                                dtype={
                                    x_col: float,
                                    y_col: float})
            extra = extra.dropna()
            extra = _passing_cutoffs(extra, x_col, y_col, cutoff_x, cutoff_y)
            axScatter.scatter(extra[x_col], extra[y_col],
                              s=markersize,
                              color='#4878D0')

        # save this figure explicitly rather than pyplot's "current" figure
        fig.savefig(outfile, dpi=300, transparent=True)
    finally:
        # release the figure from pyplot's global registry
        plt.close(fig)
@click.command()
@click.option('--infile', required=True, help='tsv文件,包含header')
@click.option('--x-col', required=True, help='x轴值列名')
@click.option('--y-col', required=True, help='y轴值列名')
@click.option('--cutoff-x', required=True, help='highlight cutoff in x axis', type=float)
@click.option('--cutoff-y', required=True, help='hightlight cutoff in y axis', type=float)
@click.option('--highlightdata', help='额外这些位点进行highlight(过了阈值线的),数据输入格式同infile,至少包括xcol和ycol对应的两列', type=str, default=None)
@click.option('--markersize', default=3, help='散点大小, default is 3', type=float)
@click.option('--outfile', required=True, help='输出文件,根据拓展名判断输出格式')
def main(infile, x_col, y_col, cutoff_x, cutoff_y, highlightdata, markersize, outfile):
    """
    Draw a joint scatter plot with marginal histograms from a TSV file.

    Reads the x/y columns from --infile, plots them with dashed cutoff lines,
    highlights the points that pass both cutoffs and writes the figure to
    --outfile.
    """
    # The input, column names, cutoffs and output path are all needed by the
    # loading and plotting steps; they are declared required so that a missing
    # option yields a usage error instead of a traceback from passing None on.
    df = load_infile(infile, x_col, y_col)
    plot(df, x_col, y_col, cutoff_x, cutoff_y, highlightdata, markersize, outfile)
# Run the command-line interface only when executed as a script.
if __name__ == "__main__":
    main()