001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *   http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.compress.utils;
019
020import java.io.File;
021import java.io.IOException;
022import java.nio.ByteBuffer;
023import java.nio.channels.ClosedChannelException;
024import java.nio.channels.NonWritableChannelException;
025import java.nio.channels.SeekableByteChannel;
026import java.nio.file.Files;
027import java.nio.file.Path;
028import java.nio.file.StandardOpenOption;
029import java.util.ArrayList;
030import java.util.Arrays;
031import java.util.Collections;
032import java.util.List;
033import java.util.Objects;
034
/**
 * Read-Only Implementation of {@link SeekableByteChannel} that
 * concatenates a collection of other {@link SeekableByteChannel}s.
 *
 * <p>This is a loose port of <a
 * href="https://github.com/frugalmechanic/fm-common/blob/master/jvm/src/main/scala/fm/common/MultiReadOnlySeekableByteChannel.scala">MultiReadOnlySeekableByteChannel</a>
 * by Tim Underwood.</p>
 *
 * <p>{@link #read(ByteBuffer)} and the positioning methods are {@code synchronized} on this instance.</p>
 *
 * @since 1.19
 */
public class MultiReadOnlySeekableByteChannel implements SeekableByteChannel {

    private static final Path[] EMPTY_PATH_ARRAY = {};

    /**
     * Concatenates the given files.
     *
     * @param files the files to concatenate
     * @throws NullPointerException if files is null
     * @throws IOException if opening a channel for one of the files fails
     * @return SeekableByteChannel that concatenates all provided files
     */
    public static SeekableByteChannel forFiles(final File... files) throws IOException {
        final List<Path> paths = new ArrayList<>();
        for (final File f : Objects.requireNonNull(files, "files must not be null")) {
            paths.add(f.toPath());
        }

        return forPaths(paths.toArray(EMPTY_PATH_ARRAY));
    }

    /**
     * Concatenates the given file paths.
     *
     * <p>When used for a split ZIP archive, the paths must be given in segment order, with the last
     * segment ({@code .zip}) last, e.g. {@code .z01, .z02, ..., .z99, .zip}.</p>
     *
     * <p>If opening any of the paths fails, every channel opened so far is closed before the
     * exception is propagated, so no file handles are leaked.</p>
     *
     * @param paths the file paths to concatenate
     * @return the channel itself if exactly one path was given, otherwise a SeekableByteChannel
     *         that concatenates all provided files
     * @throws NullPointerException if paths is null
     * @throws IOException if opening a channel for one of the files fails
     * @since 1.22
     */
    public static SeekableByteChannel forPaths(final Path... paths) throws IOException {
        final List<SeekableByteChannel> channels = new ArrayList<>();
        for (final Path path : Objects.requireNonNull(paths, "paths must not be null")) {
            try {
                channels.add(Files.newByteChannel(path, StandardOpenOption.READ));
            } catch (final IOException | RuntimeException ex) {
                // Don't leak the handles that were already opened successfully.
                closeAll(channels, ex);
                throw ex;
            }
        }
        if (channels.size() == 1) {
            return channels.get(0);
        }
        return new MultiReadOnlySeekableByteChannel(channels);
    }

    /**
     * Closes every given channel, attaching close failures to {@code primary} as suppressed exceptions.
     *
     * @param opened  channels to close
     * @param primary the exception currently being propagated
     */
    private static void closeAll(final List<SeekableByteChannel> opened, final Throwable primary) {
        for (final SeekableByteChannel ch : opened) {
            try {
                ch.close();
            } catch (final IOException suppressed) {
                primary.addSuppressed(suppressed);
            }
        }
    }

    /**
     * Concatenates the given channels.
     *
     * @param channels the channels to concatenate
     * @throws NullPointerException if channels is null
     * @return the channel itself if exactly one channel was given, otherwise a SeekableByteChannel
     *         that concatenates all provided channels
     */
    public static SeekableByteChannel forSeekableByteChannels(final SeekableByteChannel... channels) {
        if (Objects.requireNonNull(channels, "channels must not be null").length == 1) {
            return channels[0];
        }
        return new MultiReadOnlySeekableByteChannel(Arrays.asList(channels));
    }

    /** The wrapped channels, in concatenation order (defensive, unmodifiable copy). */
    private final List<SeekableByteChannel> channels;

    /** Current position across all channels, as if they formed a single channel. */
    private long globalPosition;

    /** Index of the channel the next read starts from; {@code channels.size()} means EOF. */
    private int currentChannelIdx;

    /** Set by {@link #close()} so isOpen() reports false even when no channels are wrapped. */
    private volatile boolean closed;

    /**
     * Concatenates the given channels.
     *
     * @param channels the channels to concatenate
     * @throws NullPointerException if channels is null
     */
    public MultiReadOnlySeekableByteChannel(final List<SeekableByteChannel> channels) {
        this.channels = Collections.unmodifiableList(new ArrayList<>(
            Objects.requireNonNull(channels, "channels must not be null")));
    }

    /**
     * Closes all wrapped channels, attempting every one even if some fail.
     *
     * @throws IOException wrapping the first close failure; later failures are attached to it as
     *         suppressed exceptions
     */
    @Override
    public void close() throws IOException {
        closed = true;
        IOException first = null;
        for (final SeekableByteChannel ch : channels) {
            try {
                ch.close();
            } catch (final IOException ex) {
                if (first == null) {
                    first = ex;
                } else {
                    first.addSuppressed(ex);
                }
            }
        }
        if (first != null) {
            throw new IOException("failed to close wrapped channel", first);
        }
    }

    /**
     * Tells whether this channel is open.
     *
     * @return false once {@link #close()} was called or any wrapped channel is closed
     */
    @Override
    public boolean isOpen() {
        return !closed && channels.stream().allMatch(SeekableByteChannel::isOpen);
    }

    /**
     * Gets this channel's position.
     *
     * <p>This method violates the contract of {@link SeekableByteChannel#position()} as it will not throw any exception
     * when invoked on a closed channel. Instead it will return the position the channel had when close has been
     * called.</p>
     */
    @Override
    public synchronized long position() {
        // synchronized: all writers of this long field are synchronized, and unsynchronized
        // reads of a non-volatile long may observe a torn value.
        return globalPosition;
    }

    /**
     * Sets the global position across all channels.
     *
     * <p>Channels before the target are parked at their end, the target channel is positioned at the
     * remaining offset and all later channels are rewound to 0. A position beyond the total size is
     * accepted; subsequent reads then report end-of-stream.</p>
     *
     * @param newPosition the new global position, must not be negative
     * @return this channel
     * @throws IllegalArgumentException if newPosition is negative
     * @throws ClosedChannelException if this channel is closed
     * @throws IOException if positioning one of the wrapped channels fails
     */
    @Override
    public synchronized SeekableByteChannel position(final long newPosition) throws IOException {
        if (newPosition < 0) {
            throw new IllegalArgumentException("Negative position: " + newPosition);
        }
        if (!isOpen()) {
            throw new ClosedChannelException();
        }

        final int channelCount = channels.size();
        // Stays at channelCount when newPosition lies past the last channel.
        int targetIdx = channelCount;
        long remaining = newPosition;

        for (int i = 0; i < channelCount; i++) {
            final SeekableByteChannel ch = channels.get(i);
            final long size = ch.size();
            if (targetIdx < channelCount) {
                // Target already found: later channels must be read from their start.
                ch.position(0);
            } else if (remaining <= size) {
                // This channel contains the requested position.
                targetIdx = i;
                ch.position(remaining);
            } else {
                // newPosition is past this channel: park it at its end and keep going.
                remaining -= size;
                ch.position(size);
            }
        }

        currentChannelIdx = targetIdx;
        globalPosition = newPosition;
        return this;
    }

    /**
     * Sets the position based on the given channel number and relative offset.
     *
     * <p>The resulting global position is {@code relativeOffset} plus the sizes of all channels that
     * precede channel {@code channelNumber}.</p>
     *
     * @param channelNumber  the zero-based channel number
     * @param relativeOffset the relative offset in the corresponding channel
     * @return this channel
     * @throws ClosedChannelException if this channel is closed
     * @throws IndexOutOfBoundsException if channelNumber is greater than the number of channels
     * @throws IllegalArgumentException if the resulting global position is negative
     * @throws IOException if positioning fails
     */
    public synchronized SeekableByteChannel position(final long channelNumber, final long relativeOffset) throws IOException {
        if (!isOpen()) {
            throw new ClosedChannelException();
        }
        long target = relativeOffset;
        for (int i = 0; i < channelNumber; i++) {
            target += channels.get(i).size();
        }

        return position(target);
    }

    /**
     * Reads bytes into {@code dst}, crossing channel boundaries as needed.
     *
     * @param dst the buffer to fill
     * @return the number of bytes read, 0 if {@code dst} has no space left, or -1 at end of all channels
     * @throws ClosedChannelException if this channel is closed
     * @throws IOException if reading a wrapped channel fails
     */
    @Override
    public synchronized int read(final ByteBuffer dst) throws IOException {
        if (!isOpen()) {
            throw new ClosedChannelException();
        }
        if (!dst.hasRemaining()) {
            return 0;
        }

        int totalBytesRead = 0;
        while (dst.hasRemaining() && currentChannelIdx < channels.size()) {
            final SeekableByteChannel currentChannel = channels.get(currentChannelIdx);
            final int newBytesRead = currentChannel.read(dst);
            if (newBytesRead == -1) {
                // EOF for this channel -- advance to next channel idx
                currentChannelIdx++;
                continue;
            }
            if (currentChannel.position() >= currentChannel.size()) {
                // we are at the end of the current channel
                currentChannelIdx++;
            }
            totalBytesRead += newBytesRead;
        }
        if (totalBytesRead > 0) {
            globalPosition += totalBytesRead;
            return totalBytesRead;
        }
        return -1;
    }

    /**
     * Returns the sum of the sizes of all wrapped channels.
     *
     * @return the total size
     * @throws ClosedChannelException if this channel is closed
     * @throws IOException if querying a wrapped channel's size fails
     */
    @Override
    public long size() throws IOException {
        if (!isOpen()) {
            throw new ClosedChannelException();
        }
        long acc = 0;
        for (final SeekableByteChannel ch : channels) {
            acc += ch.size();
        }
        return acc;
    }

    /**
     * @throws NonWritableChannelException since this implementation is read-only.
     */
    @Override
    public SeekableByteChannel truncate(final long size) {
        throw new NonWritableChannelException();
    }

    /**
     * @throws NonWritableChannelException since this implementation is read-only.
     */
    @Override
    public int write(final ByteBuffer src) {
        throw new NonWritableChannelException();
    }

}