#vim: set syntax=python

__author__ = "Johannes Köster"
__copyright__ = "Copyright 2015, Johannes Köster, Liu lab"
__email__ = "koester@jimmy.harvard.edu"
__license__ = "MIT"

"""
A CRISPR/Cas9 analysis workflow using MAGeCK, FastQC and VISPR.
"""


# Load the workflow configuration into the global `config` dict that every
# rule below reads from.
configfile: "config.yaml"


import sys
import yaml
from mageck_vispr import (postprocess_config, vispr_config, get_fastq,
                          annotation_available, get_counts, design_available,
                          get_norm_method, COMBAT_SCRIPT_PATH)


# Post-process the loaded configuration before any rule is defined. The return
# value is ignored, so this presumably completes/validates `config` in place
# (e.g. deriving config["replicates"]) — confirm in mageck_vispr.
postprocess_config(config)


# Default target: one VISPR configuration file for each experiment listed
# under config["experiments"].
rule all:
    input:
        expand(
            "results/{experiment}.vispr.yaml",
            experiment=list(config["experiments"]))

if "samples" in config:
    # Raw-read QC: run FastQC on each replicate's input file(s). The output
    # directory is created and emptied first so stale reports never linger.
    rule fastqc:
        input:
            lambda wildcards: config["replicates"][wildcards.replicate]
        output:
            "results/qc/{replicate}"
        log:
            "logs/fastqc/{replicate}.log"
        shell:
            "mkdir -p {output}; rm -rf {output}/*; "
            "fastqc -f fastq --extract -o {output} {input} 2> {log}"


    if "adapter" in config["sgrnas"]:
        # Adapter trimming with cutadapt (adapter from config["sgrnas"]["adapter"]).
        # Trimmed reads are written to stdout and redirected into the output
        # file, so cutadapt emits its trimming report on stderr; capture it in a
        # per-replicate log like every other rule instead of losing it.
        rule cutadapt:
            input:
                lambda wildcards: config["replicates"][wildcards.replicate]
            output:
                "results/trimmed_reads/{replicate}.fastq"
            log:
                "logs/cutadapt/{replicate}.log"
            shell:
                "cutadapt -a {config[sgrnas][adapter]} {input} > {output} 2> {log}"


    # Count sgRNA reads for all samples against the library with `mageck count`.
    # params.fastqs holds one space-separated group per sample; each group is the
    # comma-joined FASTQ paths of that sample's replicates, in the same order as
    # the sample names in params.labels (both iterate config["samples"]).
    # NOTE(review): get_fastq(rep, config) presumably returns the trimmed reads
    # when an adapter is configured — confirm in mageck_vispr.
    rule mageck_count:
        input:
            fastqs=[get_fastq(rep, config) for rep in config["replicates"]],
            library=config["library"]
        output:
            "results/count/all.count.txt",
            "results/count/all.count_normalized.txt",
            "results/count/all.countsummary.txt"
        params:
            labels=",".join(config["samples"].keys()),
            norm=get_norm_method(config),
            fastqs=" ".join(
                ",".join(get_fastq(rep, config) for rep in replicates)
                for replicates in config["samples"].values()),
            prefix="results/count/all",
            # Optional flags: empty strings when the config key is absent.
            day0=(
                "" if not "day0label" in config
                else "--day0-label "+config["day0label"]),
            controlsg=(
                "" if not "control_sgrna" in config
                else "--control-sgrna "+config["control_sgrna"])
        log:
            "logs/mageck/count/all.log"
        shell:
            "mageck count --output-prefix {params.prefix} "
            "--norm-method {params.norm} "
            "--list-seq {input.library} "
            "--fastq {params.fastqs} --sample-label {params.labels} "
            "{params.controlsg} "
            "{params.day0} --trim-5 {config[sgrnas][trim-5]} 2> {log}"


if "counts" in config:
    # Run `mageck count` on an existing count table (from get_counts(config))
    # to produce the normalized table and count-summary QC files.
    rule mageck_qc:
        input:
            counts=get_counts(config)
        output:
            "results/count/all.count_normalized.txt",
            "results/count/all.countsummary.txt",
            "results/count/all_countsummary.R",
            "results/count/all_countsummary.Rnw"
        params:
            prefix="results/count/all",
            norm=get_norm_method(config),
            # Optional flags collapse to "" when the config key is missing.
            day0=(
                "--day0-label " + config["day0label"]
                if "day0label" in config else ""),
            controlsg=(
                "--control-sgrna " + config["control_sgrna"]
                if "control_sgrna" in config else "")
        log:
            "logs/mageck/count/all.log"
        shell:
            "mageck count --output-prefix {params.prefix} {params.day0} "
            "--norm-method {params.norm} "
            "{params.controlsg} "
            "--count-table {input.counts} 2> {log}"


if "library" in config:
    # Annotate the sgRNA library into a BED file via `mageck-vispr
    # annotate-library`, using the configured sgRNA length and genome assembly.
    rule annotate_sgrnas:
        input:
            library=config["library"]
        output:
            "annotation/sgrnas.bed"
        log:
            "logs/annotation/sgrnas.log"
        shell:
            "mageck-vispr annotate-library {input.library} "
            "--sgrna-len {config[sgrnas][len]} --assembly {config[assembly]} "
            "> {output} 2> {log}"


if "batchmatrix" in config:
    # Batch-effect removal with the ComBat script at COMBAT_SCRIPT_PATH.
    # The script reads input.counts / input.batchmatrix, so those names must
    # stay. Counts default to the normalized table when config has no "counts".
    rule remove_batch:
        input:
            counts=(
                config["counts"] if "counts" in config
                else "results/count/all.count_normalized.txt"),
            batchmatrix=config["batchmatrix"]
        output:
            "results/count/all.count.batchcorrected.txt"
        log:
            "logs/combat.log"
        script:
            COMBAT_SCRIPT_PATH


# mageck_mle and mageck_rra declare the same outputs; when both could produce a
# file, prefer the MLE rule.
ruleorder: mageck_mle > mageck_rra


# Gene-level test with `mageck test` (RRA) using the explicit treatment and
# control sample lists of config["experiments"][experiment]. It shares its
# outputs with mageck_mle; the ruleorder above prefers mageck_mle.
rule mageck_rra:
    input:
        counts=get_counts(config)
    output:
        "results/test/{experiment}.gene_summary.txt",
        "results/test/{experiment}.sgrna_summary.txt"
    params:
        prefix="results/test/{experiment}",
        treatment=lambda wildcards: ",".join(config["experiments"][wildcards.experiment]["treatment"]),
        control=lambda wildcards: ",".join(config["experiments"][wildcards.experiment]["control"]),
        norm=get_norm_method(config),
        controlsg=(
            "" if not "control_sgrna" in config
            else "--control-sgrna "+config["control_sgrna"]),
        additionalparameter=(
            "" if not "additional_mle_rra_parameter" in config
            else " "+config["additional_mle_rra_parameter"])
    log:
        "logs/mageck/test/{experiment}.log"
    shell:
        # Reference the named input (as mageck_mle does) so that adding another
        # input later cannot leak into the --count-table argument.
        "mageck test --norm-method {params.norm} "
        "--output-prefix {params.prefix} "
        "--count-table {input.counts} --treatment-id {params.treatment} "
        "{params.controlsg} "
        "--control-id {params.control} "
        "{params.additionalparameter} "
        "2> {log}"


# Gene-level test with `mageck mle` using the per-experiment design matrix.
# The has_designmatrix input looks up config["experiments"][experiment]
# ["designmatrix"]; presumably experiments without one fall back to
# mageck_rra (same outputs) — confirm against the Snakemake version in use.
# The thread count is declared via `threads:` so the scheduler reserves that
# many cores (capped by --cores) instead of treating this as a 1-core job while
# mageck actually runs multi-threaded. Without config["threads"] it is 1, which
# matches mageck's own --threads default.
rule mageck_mle:
    input:
        counts=get_counts(config),
        has_designmatrix=lambda wildcards: config["experiments"][wildcards.experiment]["designmatrix"],
        annotation="annotation/sgrnas.bed" if annotation_available(config) else []
    output:
        "results/test/{experiment}.gene_summary.txt",
        "results/test/{experiment}.sgrna_summary.txt"
    threads: int(config.get("threads", 1))
    params:
        prefix="results/test/{experiment}",
        # sgRNA efficiency scores are read from columns 3/4 of the annotation BED.
        efficiency=(
            "" if not annotation_available(config)
            else "--sgrna-eff-name-column 3 "
                 "--sgrna-eff-score-column 4 "
                 "--sgrna-efficiency annotation/sgrnas.bed"),
        update_efficiency=(
            "" if not config["sgrnas"].get("update-efficiency", False)
            else "--update-efficiency"),
        norm=get_norm_method(config),
        designmatrix=(lambda wildcards: "" if not design_available(config)
            else "--design-matrix " +  config["experiments"][wildcards.experiment]["designmatrix"]),
        day0=(
            "" if not "day0label" in config
            else "--day0-label "+config["day0label"]),
        controlsg=(
            "" if not "control_sgrna" in config
            else "--control-sgrna "+config["control_sgrna"]),
        additionalparameter=(
            "" if not "additional_mle_rra_parameter" in config
            else " "+config["additional_mle_rra_parameter"])
    log:
        "logs/mageck/test/{experiment}.log"
    shell:
        "mageck mle --norm-method {params.norm} "
        "--output-prefix {params.prefix} {params.efficiency} --genes-var 0 "
        "{params.update_efficiency} --count-table {input.counts} "
        "--threads {threads} {params.controlsg} {params.designmatrix} {params.day0} "
        "{params.additionalparameter} "
        "2> {log}"


# Build the VISPR configuration for one experiment via vispr_config(). Two
# variants exist because the available inputs depend on how counts were made.
if "samples" in config:
    # Counts were computed from FASTQ files: additionally pass the mapping
    # summary from mageck count and the FastQC report directories. The unnamed
    # annotation input is present only when annotation_available(config).
    rule vispr:
        input:
            "annotation/sgrnas.bed" if annotation_available(config) else [],
            results="results/test/{experiment}.gene_summary.txt",
            sgrna_results="results/test/{experiment}.sgrna_summary.txt",
            counts=get_counts(config, normalized=True),
            mapstats="results/count/all.countsummary.txt",
            fastqc=expand("results/qc/{replicate}", replicate=config["replicates"])
        output:
            "results/{experiment}.vispr.yaml"
        run:
            vispr_config(input, output, wildcards, config)
else:
    # No FASTQ samples: only test results and the normalized count table are
    # passed; vispr_config presumably handles the absent mapstats/fastqc inputs
    # — confirm in mageck_vispr.
    rule vispr:
        input:
            "annotation/sgrnas.bed" if annotation_available(config) else [],
            results="results/test/{experiment}.gene_summary.txt",
            sgrna_results="results/test/{experiment}.sgrna_summary.txt",
            counts="results/count/all.count_normalized.txt"
        output:
            "results/{experiment}.vispr.yaml"
        run:
            vispr_config(input, output, wildcards, config)
