""" Testing """ import platform import os import random import sys import zlib from io import BytesIO from tempfile import mkstemp from contextlib import contextmanager import numpy as np from numpy.testing import assert_, assert_equal from pytest import raises as assert_raises import pytest from scipy.io.matlab._streams import (make_stream, GenericStream, ZlibInputStream, _read_into, _read_string, BLOCK_SIZE) @contextmanager def setup_test_file(): val = b'a\x00string' fd, fname = mkstemp() with os.fdopen(fd, 'wb') as fs: fs.write(val) with open(fname, 'rb') as fs: gs = BytesIO(val) cs = BytesIO(val) yield fs, gs, cs os.unlink(fname) def test_make_stream(): with setup_test_file() as (fs, gs, cs): # test stream initialization assert_(isinstance(make_stream(gs), GenericStream)) def test_tell_seek(): with setup_test_file() as (fs, gs, cs): for s in (fs, gs, cs): st = make_stream(s) res = st.seek(0) assert_equal(res, 0) assert_equal(st.tell(), 0) res = st.seek(5) assert_equal(res, 0) assert_equal(st.tell(), 5) res = st.seek(2, 1) assert_equal(res, 0) assert_equal(st.tell(), 7) res = st.seek(-2, 2) assert_equal(res, 0) assert_equal(st.tell(), 6) def test_read(): with setup_test_file() as (fs, gs, cs): for s in (fs, gs, cs): st = make_stream(s) st.seek(0) res = st.read(-1) assert_equal(res, b'a\x00string') st.seek(0) res = st.read(4) assert_equal(res, b'a\x00st') # read into st.seek(0) res = _read_into(st, 4) assert_equal(res, b'a\x00st') res = _read_into(st, 4) assert_equal(res, b'ring') assert_raises(OSError, _read_into, st, 2) # read alloc st.seek(0) res = _read_string(st, 4) assert_equal(res, b'a\x00st') res = _read_string(st, 4) assert_equal(res, b'ring') assert_raises(OSError, _read_string, st, 2) class TestZlibInputStream: def _get_data(self, size): data = random.randbytes(size) compressed_data = zlib.compress(data) stream = BytesIO(compressed_data) return stream, len(compressed_data), data def test_read(self): SIZES = [0, 1, 10, BLOCK_SIZE//2, BLOCK_SIZE-1, BLOCK_SIZE, BLOCK_SIZE+1, 2*BLOCK_SIZE-1] READ_SIZES = [BLOCK_SIZE//2, BLOCK_SIZE-1, BLOCK_SIZE, BLOCK_SIZE+1] def check(size, read_size): compressed_stream, compressed_data_len, data = self._get_data(size) stream = ZlibInputStream(compressed_stream, compressed_data_len) data2 = b'' so_far = 0 while True: block = stream.read(min(read_size, size - so_far)) if not block: break so_far += len(block) data2 += block assert_equal(data, data2) for size in SIZES: for read_size in READ_SIZES: check(size, read_size) def test_read_max_length(self): data = random.randbytes(1234) compressed_data = zlib.compress(data) compressed_stream = BytesIO(compressed_data + b"abbacaca") stream = ZlibInputStream(compressed_stream, len(compressed_data)) stream.read(len(data)) assert_equal(compressed_stream.tell(), len(compressed_data)) assert_raises(OSError, stream.read, 1) def test_read_bad_checksum(self): data = random.randbytes(10) compressed_data = zlib.compress(data) # break checksum compressed_data = (compressed_data[:-1] + bytes([(compressed_data[-1] + 1) & 255])) compressed_stream = BytesIO(compressed_data) stream = ZlibInputStream(compressed_stream, len(compressed_data)) assert_raises(zlib.error, stream.read, len(data)) def test_seek(self): compressed_stream, compressed_data_len, data = self._get_data(1024) stream = ZlibInputStream(compressed_stream, compressed_data_len) stream.seek(123) p = 123 assert_equal(stream.tell(), p) d1 = stream.read(11) assert_equal(d1, data[p:p+11]) stream.seek(321, 1) p = 123+11+321 assert_equal(stream.tell(), p) d2 = stream.read(21) assert_equal(d2, data[p:p+21]) stream.seek(641, 0) p = 641 assert_equal(stream.tell(), p) d3 = stream.read(11) assert_equal(d3, data[p:p+11]) assert_raises(OSError, stream.seek, 10, 2) assert_raises(OSError, stream.seek, -1, 1) assert_raises(ValueError, stream.seek, 1, 123) stream.seek(10000, 1) assert_raises(OSError, stream.read, 12) def test_seek_bad_checksum(self): data = random.randbytes(10) compressed_data = zlib.compress(data) # break checksum compressed_data = (compressed_data[:-1] + bytes([(compressed_data[-1] + 1) & 255])) compressed_stream = BytesIO(compressed_data) stream = ZlibInputStream(compressed_stream, len(compressed_data)) assert_raises(zlib.error, stream.seek, len(data)) def test_all_data_read(self): compressed_stream, compressed_data_len, data = self._get_data(1024) stream = ZlibInputStream(compressed_stream, compressed_data_len) assert_(not stream.all_data_read()) stream.seek(512) assert_(not stream.all_data_read()) stream.seek(1024) assert_(stream.all_data_read()) @pytest.mark.skipif( (platform.system() == 'Windows' and sys.version_info >= (3, 14)), reason='gh-23185') def test_all_data_read_overlap(self): COMPRESSION_LEVEL = 6 data = np.arange(33707000, dtype=np.uint8) compressed_data = zlib.compress(data, COMPRESSION_LEVEL) compressed_data_len = len(compressed_data) # check that part of the checksum overlaps assert_(compressed_data_len == BLOCK_SIZE + 2) compressed_stream = BytesIO(compressed_data) stream = ZlibInputStream(compressed_stream, compressed_data_len) assert_(not stream.all_data_read()) stream.seek(len(data)) assert_(stream.all_data_read()) @pytest.mark.skipif( (platform.system() == 'Windows' and sys.version_info >= (3, 14)), reason='gh-23185') def test_all_data_read_bad_checksum(self): COMPRESSION_LEVEL = 6 data = np.arange(33707000, dtype=np.uint8) compressed_data = zlib.compress(data, COMPRESSION_LEVEL) compressed_data_len = len(compressed_data) # check that part of the checksum overlaps assert_(compressed_data_len == BLOCK_SIZE + 2) # break checksum compressed_data = (compressed_data[:-1] + bytes([(compressed_data[-1] + 1) & 255])) compressed_stream = BytesIO(compressed_data) stream = ZlibInputStream(compressed_stream, compressed_data_len) assert_(not stream.all_data_read()) stream.seek(len(data)) assert_raises(zlib.error, stream.all_data_read)