Distance test

[1]:
%load_ext autoreload
%autoreload 2
from creme import creme
from creme import utils
from creme import custom_model
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
2024-06-09 03:53:12.112293: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.

Load Enformer and example sequences

[2]:
data_dir = '../../../data'
track_index = [5111]
model = custom_model.Enformer(track_index=track_index)

2024-06-09 03:53:13.903910: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1613] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 58738 MB memory:  -> device: 0, name: NVIDIA A100 80GB PCIe, pci bus id: 0000:85:00.0, compute capability: 8.0
[3]:
fasta_path = f'{data_dir}/GRCh38.primary_assembly.genome.fa'
seq_parser = utils.SequenceParser(fasta_path)

gene = 'GATA2_chr3_128487916_-'

gene_name, chrom, start, strand = gene.split('_')
wt_seq = seq_parser.extract_seq_centered(chrom, int(start), strand, model.seq_length)

[4]:
# TSS bin indices
bins = [447, 448]
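Enformer takes a 196,608 bp input and predicts 896 bins of 128 bp each, so predictions cover only the central 114,688 bp. The sketch below assumes that `extract_seq_centered` places the TSS at the sequence midpoint. It shows why the TSS falls on the boundary between bins 447 and 448. The helper `bin_of_position` is hypothetical and is not part of `creme`.

```python
SEQ_LENGTH = 196_608  # Enformer input length (bp)
BIN_SIZE = 128        # bp per output bin
N_BINS = 896          # number of output bins

def bin_of_position(pos, seq_length=SEQ_LENGTH, bin_size=BIN_SIZE, n_bins=N_BINS):
    """Return the output-bin index covering 0-based input position `pos`,
    or None if the position lies in the cropped flanks (hypothetical helper)."""
    # Predictions cover only the central n_bins * bin_size bases
    offset = (seq_length - n_bins * bin_size) // 2  # 40,960 bp cropped per side
    idx = (pos - offset) // bin_size
    return idx if 0 <= idx < n_bins else None

center = SEQ_LENGTH // 2  # assumed TSS position in the centered sequence
print(bin_of_position(center - 1), bin_of_position(center))  # 447 448
```

The base just upstream of the midpoint lands in bin 447 and the midpoint itself starts bin 448. For this reason, both bins are treated as TSS bins.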
[5]:
wt = model.predict(wt_seq)[0,:,0]

2024-06-09 03:53:16.898851: I tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:630] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
2024-06-09 03:53:17.021440: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:428] Loaded cuDNN version 8401
[6]:
utils.plot_track([wt], color='green', zoom=[0, 896], marks=bins)

[6]:
<Axes: >
[Figure: predicted WT track across all 896 bins, with TSS bins 447 and 448 marked]

Distance test

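In this example, we test the distance dependence of a 5 kb enhancing CRE of the GATA2 gene. The CRE is embedded at different positions in the sequence and the predicted TSS activity is recorded at each position.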

To run the test we need:
- a loaded model
- the one-hot encoded wild-type (WT) sequence
- the [start, end] coordinates of the fixed tile (here, the TSS tile)
- the [start, end] coordinates of the CRE tile
- the test start positions, i.e. where the CRE will be embedded to probe distance-dependent effects
- num_shuffle: the number of shuffled background sequences
- optionally, mean=False to return the result for every shuffle instead of the average
- optionally, seed=True to use the same background sequences at every test position
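The ingredients listed above fit together roughly as follows. The sketch below is hypothetical and is not the `creme.distance_test` source. Each background is a shuffled copy of the WT sequence with the WT TSS tile restored. The WT CRE tile is then placed either at its original position (control) or at each test start position. The simple mononucleotide shuffle is a stand-in for whatever shuffling scheme CREME actually uses. The `predict` callable stands in for `model.predict`.

```python
import numpy as np

def shuffle_onehot(onehot, rng):
    """Shuffle the positions (rows) of a one-hot (L, 4) array.
    Mononucleotide shuffle: a stand-in, not necessarily CREME's scheme."""
    return onehot[rng.permutation(len(onehot))]  # fancy indexing returns a copy

def distance_test_sketch(predict, wt_seq, tss_tile, cre_tile, test_starts,
                         num_shuffle, mean=True, seed=True):
    """Hypothetical sketch of a distance test (not the creme implementation).

    Assumes test_starts do not overlap the TSS tile and that every
    start + CRE length fits inside the sequence.
    """
    tss_start, tss_end = tss_tile
    cre_start, cre_end = cre_tile
    cre = wt_seq[cre_start:cre_end]
    rng = np.random.default_rng(0)

    def make_backgrounds():
        backgrounds = []
        for _ in range(num_shuffle):
            bg = shuffle_onehot(wt_seq, rng)
            bg[tss_start:tss_end] = wt_seq[tss_start:tss_end]  # keep the WT TSS tile
            backgrounds.append(bg)
        return backgrounds

    def embed_and_predict(bg, start):
        seq = bg.copy()
        seq[start:start + len(cre)] = cre  # place the WT CRE tile at `start`
        return predict(seq)

    shared = make_backgrounds()  # reused at every test position when seed=True
    control = np.array([embed_and_predict(bg, cre_start) for bg in shared])
    mut = []
    for start in test_starts:
        backgrounds = shared if seed else make_backgrounds()
        mut.append([embed_and_predict(bg, start) for bg in backgrounds])
    mut = np.array(mut)  # (n_positions, num_shuffle, ...)

    if mean:
        return {'control': control.mean(axis=0), 'mut': mut.mean(axis=1)}
    return {'control': control, 'mut': mut}

# Toy usage: an "identity model" returns the sequence itself,
# so the embedded tiles can be inspected directly.
rng = np.random.default_rng(1)
wt = np.eye(4)[rng.integers(0, 4, size=100)]
res = distance_test_sketch(lambda s: s, wt, (45, 55), (55, 65),
                           [0, 10, 20, 30, 65, 75, 85], num_shuffle=3, mean=False)
print(res['control'].shape, res['mut'].shape)  # (3, 100, 4) (7, 3, 100, 4)
```

With mean=False, the shuffle axis is kept (axis 0 of `control`, axis 1 of `mut`), which matches how the results are averaged later in this notebook.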

[7]:

perturb_window = 5000 N_shuffles = 10 tss_tile, cre_tiles = utils.set_tile_range(model.seq_length, perturb_window) enhancing_cre_tile = cre_tiles[19] print(f'Enhancing tile at position {enhancing_cre_tile[0]} - {enhancing_cre_tile[1]}') print(f'TSS tile at center position {tss_tile[0]} - {tss_tile[1]}')
Enhancing tile at position 100804 - 105804
TSS tile at center position 95804 - 100804
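The printed coordinates follow from a simple layout. A 5 kb TSS tile is centered on the sequence midpoint, and non-overlapping 5 kb tiles extend outward on both sides. The function below is a hypothetical re-derivation that reproduces the output above for Enformer's 196,608 bp input. The real `utils.set_tile_range` may differ in details, such as how leftover bases at the edges are handled.

```python
def tile_range_sketch(seq_length, window):
    """Hypothetical re-derivation of the tile layout (not the creme source).

    Returns the centered TSS tile and the flanking tiles, sorted by start.
    """
    center = seq_length // 2
    tss_start = center - window // 2
    tss_tile = [tss_start, tss_start + window]

    upstream = []  # tiles to the left of the TSS tile, built outward
    start = tss_start - window
    while start >= 0:
        upstream.append([start, start + window])
        start -= window

    downstream = []  # tiles to the right of the TSS tile, built outward
    start = tss_tile[1]
    while start + window <= seq_length:
        downstream.append([start, start + window])
        start += window

    return tss_tile, upstream[::-1] + downstream

tss_tile, cre_tiles = tile_range_sketch(196_608, 5000)
print(tss_tile, len(cre_tiles), cre_tiles[19])  # [95804, 100804] 38 [100804, 105804]
```

Under this layout there are 19 tiles on each side of the TSS. Index 19 is therefore the first tile downstream of the TSS tile, and the 38 tiles match the 38 test positions used below.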
[8]:
test_start_positions = np.array(cre_tiles)[:,0]

Distance effect of an enhancing tile on the GATA2 TSS

[9]:
%%time
results = creme.distance_test(model, wt_seq, tss_tile, enhancing_cre_tile,
                                  test_start_positions, N_shuffles, mean=False, seed=True)
100%|███████████████████████████████████████████████████████████████████████████| 38/38 [02:14<00:00,  3.55s/it]
CPU times: user 40.2 s, sys: 318 ms, total: 40.5 s
Wall time: 2min 18s

Breakdown of results

[41]:
utils.plot_track([results['control'].mean(axis=0)[:,0]], zoom=[400,500], alpha=1, color='k')
plt.title('Mean of control sequences (CRE at original position)')
[41]:
Text(0.5, 1.0, 'Mean of control sequences (CRE at original position)')
[Figure: mean control track (CRE at its original position), bins 400 to 500]

We can plot the mean prediction (across shuffles) for each test position.

[11]:
for one_position_result in results['mut'].mean(axis=1)[...,0]:
    utils.plot_track([one_position_result], zoom=[400,500], alpha=0.4, marks=bins)
    # Display and close each figure so that 38 figures are not kept open at once
    plt.show()
    plt.close('all')
[38 figures: mean predicted track (bins 400 to 500, TSS bins marked), one per CRE test position]

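We can also summarize each test position as a single TSS activity value, defined as the prediction averaged over shuffles and over the two TSS bins. We then plot this value (y-axis) against the CRE position (x-axis) and highlight selected positions to plot in detail.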
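The reduction from the full results array to one value per position can be checked on synthetic data. The array shape below, (positions, shuffles, bins, tracks), is an assumption inferred from how `results['mut']` is indexed in this notebook. The random values are placeholders, not model output.

```python
import numpy as np

# Synthetic stand-in for results['mut']: (n_positions, n_shuffles, n_bins, n_tracks)
n_positions, n_shuffles, n_bins, n_tracks = 38, 10, 896, 1
rng = np.random.default_rng(0)
mut = rng.random((n_positions, n_shuffles, n_bins, n_tracks))
bins = [447, 448]

# Mean over shuffles (axis=1), then keep the first (only) track
per_position_tracks = mut.mean(axis=1)[..., 0]             # (38, 896)
# Select the two TSS bins and average them: one value per test position
tss_activity = per_position_tracks[:, bins].mean(axis=-1)  # (38,)
# Index of the test position with the highest summarized activity
best_position = int(np.argmax(tss_activity))
```

The shuffles and bins carry equal weight, so this value equals the plain mean over all shuffles and both TSS bins at each position.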

[42]:
# Mean over shuffles, first track, then mean over the two TSS bins
average_result_per_position = results['mut'].mean(axis=1)[...,0][:, bins].mean(axis=-1)
selection = [0, 19, 26]
colors = ['red', 'green', 'orange']

fig, ax = plt.subplots(1)
ax.scatter(test_start_positions, average_result_per_position, color='#655C83')
ax.set_xlabel('Tile coordinate')
ax.set_ylabel('TSS activity')

for idx, color in zip(selection, colors):
    # Circle the selected position on the scatter plot
    ax.scatter(test_start_positions[idx], average_result_per_position[idx],
               facecolors='none', edgecolors=color, s=80)
    # Plot the full mean track for that position
    utils.plot_track([results['mut'].mean(axis=1)[idx,:,0]], color=color)

[Figures: TSS activity versus tile coordinate with positions 0, 19 and 26 circled, followed by the mean tracks for those three positions]