1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162
|
#!/usr/bin/python3
# Copyright (c) 2008-2025 the MRtrix3 contributors.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
#
# Covered Software is provided under this License on an "as is"
# basis, without warranty of any kind, either expressed, implied, or
# statutory, including, without limitation, warranties that the
# Covered Software is free of defects, merchantable, fit for a
# particular purpose or non-infringing.
# See the Mozilla Public License v. 2.0 for more details.
#
# For more details, see http://www.mrtrix.org/.
# note: deal with these warnings properly when we drop support for Python 2:
# pylint: disable=unspecified-encoding
import json, shutil
def usage(cmdline): #pylint: disable=unused-variable
    # Describe this script's command-line interface on the supplied parser
    # object: author, synopsis, long description, then the two positional
    # arguments (inputs, output) followed by the optional -mask argument.
    author_text = 'Lena Dorfschmidt (ld548@cam.ac.uk) and Jakub Vohryzek (jakub.vohryzek@queens.ox.ac.uk) and Robert E. Smith (robert.smith@florey.edu.au)'
    synopsis_text = 'Concatenating multiple DWI series accounting for differential intensity scaling'
    description_text = ''.join((
        'This script concatenates two or more 4D DWI series, accounting for the ',
        'fact that there may be differences in intensity scaling between those series. ',
        'This intensity scaling is corrected by determining scaling factors that will ',
        'make the overall image intensities in the b=0 volumes of each series approximately ',
        'equivalent.',
    ))

    cmdline.set_author(author_text)
    cmdline.set_synopsis(synopsis_text)
    cmdline.add_description(description_text)

    # (argument name, keyword options) pairs, registered in this exact order
    argument_table = (
        ('inputs', {'nargs': '+', 'help': 'Multiple input diffusion MRI series'}),
        ('output', {'help': 'The output image series (all DWIs concatenated)'}),
        ('-mask', {'metavar': 'image', 'help': 'Provide a binary mask within which image intensities will be matched'}),
    )
    for argument_name, argument_options in argument_table:
        cmdline.add_argument(argument_name, **argument_options)
def execute(): #pylint: disable=unused-variable
    # Concatenate all input DWI series into one output series, after scaling
    # every series but the first so that its b=0 intensities match those of
    # the first series.
    #
    # Processing order:
    #   1. validate the header of each input (and of the mask, if given);
    #   2. copy the inputs (and mask) into the scratch directory;
    #   3. extract the b=0 volumes of each input;
    #   4. for inputs 1..N-1, run mrhistmatch against input 0 and multiply
    #      the whole series by the scale factor it reports;
    #   5. concatenate along axis 3 and write the output, with the
    #      "command_history" header entry removed.
    #
    # Raises MRtrixError if fewer than two inputs are given, if any header
    # fails validation, or if the scale factor cannot be read back from the
    # mrhistmatch output.
    from mrtrix3 import CONFIG, MRtrixError #pylint: disable=no-name-in-module, import-outside-toplevel
    from mrtrix3 import app, image, path, run #pylint: disable=no-name-in-module, import-outside-toplevel

    num_inputs = len(app.ARGS.inputs)
    if num_inputs < 2:
        raise MRtrixError('Script requires at least two input image series')

    # check input data
    def check_header(header):
        # Validate one input header: at most 4 dimensions, a gradient table
        # present in a recognised layout, and either one gradient line per
        # volume (4D) or a single line whose b-value is within the b=0
        # threshold (otherwise). Raises MRtrixError on any failure.
        if len(header.size()) > 4:
            raise MRtrixError('Image "' + header.name() + '" contains more than 4 dimensions')
        if 'dw_scheme' not in header.keyval():
            raise MRtrixError('Image "' + header.name() + '" does not contain a gradient table')
        dw_scheme = header.keyval()['dw_scheme']
        # Two layouts are accepted: a list of rows (one list per gradient
        # line), or a flat list of at least 4 numbers holding a single line.
        # Keep a handle on the first line so that the b-value check below
        # works for both layouts.
        try:
            if isinstance(dw_scheme[0], list):
                num_grad_lines = len(dw_scheme)
                first_grad_line = dw_scheme[0]
            elif isinstance(dw_scheme[0], (int, float)) and len(dw_scheme) >= 4:
                num_grad_lines = 1
                first_grad_line = dw_scheme
            else:
                raise MRtrixError
        except (IndexError, TypeError, MRtrixError):
            # IndexError: empty table; TypeError: entry that cannot be indexed
            raise MRtrixError('Image "' + header.name() + '" contains gradient table of unknown format')
        if len(header.size()) == 4:
            num_volumes = header.size()[3]
            if num_grad_lines != num_volumes:
                raise MRtrixError('Number of lines in gradient table for image "' + header.name() + '" (' + str(num_grad_lines) + ') does not match number of volumes (' + str(num_volumes) + ')')
        elif not (num_grad_lines == 1 and len(first_grad_line) >= 4 and first_grad_line[3] <= float(CONFIG.get('BZeroThreshold', 10.0))):
            # Fewer than 4 dimensions: only acceptable as a lone b=0 volume,
            # i.e. a single gradient line whose 4th entry does not exceed
            # the configured BZeroThreshold (default 10.0)
            raise MRtrixError('Image "' + header.name() + '" is 3D, and cannot be validated as a b=0 volume')

    first_header = image.Header(path.from_user(app.ARGS.inputs[0], False))
    check_header(first_header)
    warn_protocol_mismatch = False
    for filename in app.ARGS.inputs[1:]:
        this_header = image.Header(path.from_user(filename, False))
        check_header(this_header)
        if this_header.size()[0:3] != first_header.size()[0:3]:
            raise MRtrixError('Spatial dimensions of image "' + filename + '" do not match those of first image "' + first_header.name() + '"')
        # Only flag a mismatch when both headers actually provide the field
        for field_name in [ 'EchoTime', 'RepetitionTime', 'FlipAngle' ]:
            first_value = first_header.keyval().get(field_name)
            this_value = this_header.keyval().get(field_name)
            if first_value and this_value and first_value != this_value:
                warn_protocol_mismatch = True
    if warn_protocol_mismatch:
        app.warn('Mismatched protocol acquisition parameters detected between input images; ' + \
                 'the assumption of equivalent intensities between b=0 volumes of different inputs underlying operation of this script may not be valid')
    if app.ARGS.mask:
        mask_header = image.Header(path.from_user(app.ARGS.mask, False))
        if mask_header.size()[0:3] != first_header.size()[0:3]:
            raise MRtrixError('Spatial dimensions of mask image "' + app.ARGS.mask + '" do not match those of first image "' + first_header.name() + '"')

    # check output path
    app.check_output_path(path.from_user(app.ARGS.output, False))

    # import data to scratch directory
    app.make_scratch_dir()
    for index, filename in enumerate(app.ARGS.inputs):
        run.command('mrconvert ' + path.from_user(filename) + ' ' + path.to_scratch(str(index) + 'in.mif'))
    if app.ARGS.mask:
        run.command('mrconvert ' + path.from_user(app.ARGS.mask) + ' ' + path.to_scratch('mask.mif') + ' -datatype bit')
    app.goto_scratch_dir()

    # extract b=0 volumes within each input series
    for index in range(0, num_inputs):
        infile = str(index) + 'in.mif'
        outfile = str(index) + 'b0.mif'
        if len(image.Header(infile).size()) > 3:
            run.command('dwiextract ' + infile + ' ' + outfile + ' -bzero')
        else:
            # Already validated above as a single b=0 volume: use it as-is
            run.function(shutil.copyfile, infile, outfile)

    # The same mask is applied to both the input and the target of the matching
    mask_option = ' -mask_input mask.mif -mask_target mask.mif' if app.ARGS.mask else ''

    # for all but the first image series:
    #   - find multiplicative factor to match b=0 images to those of the first image
    #   - apply multiplicative factor to whole image series
    # It would be better to not preferentially treat one of the inputs differently to any other:
    #   - compare all inputs to all other inputs
    #   - determine one single appropriate scaling factor for each image based on all results
    #     can't do a straight geometric average: e.g. if run for 2 images, each would map to
    #     the input intensity of the other image, and so the two corrected images would not match
    #     should be some mathematical theorem providing the optimal scaling factor for each image
    #     based on the resulting matrix of optimal scaling factors
    filelist = [ '0in.mif' ]
    for index in range(1, num_inputs):
        stderr_text = run.command('mrhistmatch scale ' + str(index) + 'b0.mif 0b0.mif ' + str(index) + 'rescaledb0.mif' + mask_option).stderr
        # The factor is read from the last whitespace-separated token of the
        # first stderr line containing the marker text below
        scaling_factor = None
        for line in stderr_text.splitlines():
            if 'Estimated scale factor is' in line:
                try:
                    scaling_factor = float(line.split()[-1])
                except ValueError:
                    raise MRtrixError('Unable to convert scaling factor from mrhistmatch output to floating-point number')
                break
        if scaling_factor is None:
            raise MRtrixError('Unable to extract scaling factor from mrhistmatch output')
        filename = str(index) + 'rescaled.mif'
        run.command('mrcalc ' + str(index) + 'in.mif ' + str(scaling_factor) + ' -mult ' + filename)
        filelist.append(filename)

    # concatenate all series together
    run.command('mrcat ' + ' '.join(filelist) + ' - -axis 3 | ' + \
                'mrconvert - result.mif -json_export result_init.json -strides 0,0,0,1')

    # remove current contents of command_history, since there's no sensible
    #   way to choose from which input image the contents should be taken;
    #   we do however want to keep other contents of keyval (e.g. gradient table)
    with open('result_init.json', 'r') as input_json_file:
        keyval = json.load(input_json_file)
    keyval.pop('command_history', None)
    with open('result_final.json', 'w') as output_json_file:
        json.dump(keyval, output_json_file)

    run.command('mrconvert result.mif ' + path.from_user(app.ARGS.output), mrconvert_keyval='result_final.json', force=app.FORCE_OVERWRITE)
# Execute the script
# The module-level import of mrtrix3 is placed here, after usage() and
# execute() have been defined above.
# NOTE(review): mrtrix3.execute() presumably looks up and calls those two
# functions (which would explain their "unused-variable" pylint
# suppressions) -- confirm against the mrtrix3 library.
import mrtrix3
mrtrix3.execute() #pylint: disable=no-member
|