Median of K Sorted Arrays

LintCode 931

There are k sorted arrays nums. Find the median of the given k sorted arrays.

Notice

The lengths of the given arrays may not be equal to each other. The elements of the given arrays are all positive numbers. Return 0 if there are no elements in the arrays.

Example

Given nums = [[1],[2],[3]], return 2.00.

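This problem is actually easier than the previous one, median of two sorted arrays, because there is no need to find an O(log N)-style solution (I could not find one online).

The first idea is, naturally, a heap. The time complexity is O(M log K), where M is the total length of all arrays; if the arrays average N elements each, that is O(NK log K). This solution produces the correct result, but it hit TLE on a very long test case at around 88%.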

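    def findMedian(self, nums):
        import heapq
        length = 0
        minheap = []
        # Seed the heap with the head of every non-empty array:
        # (value, array index, position within that array)
        for i, array in enumerate(nums):
            if array:
                heapq.heappush(minheap, (array[0], i, 0))
                length += len(array)

        if length == 0:
            return 0.0

        # Indices of the two middle elements in merged order
        # (they coincide when length is odd)
        idx1, idx2 = (length - 1) // 2, length // 2

        count = 0
        while minheap:
            # Pop the smallest remaining element; it is the count-th overall
            n, i, j = heapq.heappop(minheap)
            if count == idx1:
                median1 = n
            if count == idx2:
                median2 = n
                break
            # Push the next element from the same array, if any
            if j + 1 < len(nums[i]):
                heapq.heappush(minheap, (nums[i][j + 1], i, j + 1))
            count += 1

        return (median1 + median2) / 2.0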
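
As a way to cross-check the heap solution locally, the same k-way merge can be sketched with `heapq.merge` from the standard library, which performs the heap-based merge lazily. The function name `median_by_merge` is hypothetical and not part of the original solution; this is a minimal sketch for testing, not a faster submission.

```python
import heapq
from itertools import islice


def median_by_merge(nums):
    """Median of k sorted lists via a lazy k-way merge (hypothetical helper)."""
    length = sum(len(array) for array in nums)
    if length == 0:
        return 0.0

    # Same middle indices as in the heap solution above
    idx1, idx2 = (length - 1) // 2, length // 2

    # heapq.merge yields the merged order lazily, so only the first
    # idx2 + 1 elements are ever produced; islice keeps indices idx1..idx2
    middle = list(islice(heapq.merge(*nums), idx1, idx2 + 1))
    return (middle[0] + middle[-1]) / 2.0
```

It has the same O(M log K) bound as the hand-written heap, so it is useful for comparing results on small inputs rather than for avoiding the TLE.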

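The second idea is to binary search on the value. This rather hacky approach turns out to be quite fast, and Python barely passes. The Java answer from Jiuzhang (九章) binary-searches directly over [0, sys.maxint]. Written that way in Python, the code gets TLE at around 77%, so the code below computes the maximum value first.

The time complexity is O(log(range) * log(N) * K). Since the range is 0 to 2^31 - 1 (the range of positive 32-bit integers), log(range) is at most 31.

On LintCode, Python code needs even this constant factor optimized to get AC.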

import bisect

class Solution:
    """
    @param nums: the given k sorted arrays
    @return: the median of the given k sorted arrays
    """
    def findMedian(self, nums):
        # Total element count, plus the largest value as the upper bound for the search
        length, self.maxValue = 0, 0
        for array in nums:
            if array:
                length += len(array)
                self.maxValue = max(self.maxValue, array[-1])

        if length == 0:
            return 0.0

        if length % 2 == 1:
            return self.findKth(nums, length // 2 + 1) * 1.0

        return (self.findKth(nums, length // 2) + self.findKth(nums, length // 2 + 1)) / 2.0


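    def findKth(self, nums, k):
        # Binary search on the value: find the smallest value v such that
        # at least k elements are <= v. That value is the k-th smallest.
        lo, hi = 0, self.maxValue
        while lo < hi:
            mid = (lo + hi) // 2
            n = self.findLessEqual(nums, mid)
            if n < k:
                lo = mid + 1
            else:
                hi = mid
        return lo


    def findLessEqual(self, nums, value):
        # Count the elements <= value across all arrays;
        # bisect.bisect is bisect_right, so it is O(log N) per array
        count = 0
        for array in nums:
            count += bisect.bisect(array, value)
        return count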
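
To see why searching on the value works, note that the count of elements less than or equal to v never decreases as v grows, so the smallest v whose count reaches k is exactly the k-th smallest element. The standalone sketch below mirrors `findKth` and `findLessEqual` outside the class; the names `count_less_equal` and `kth_smallest` are hypothetical, and it assumes non-negative integers and at least one non-empty array.

```python
import bisect


def count_less_equal(nums, value):
    # Number of elements <= value across all sorted arrays
    return sum(bisect.bisect_right(array, value) for array in nums)


def kth_smallest(nums, k):
    # Assumes non-negative integers and at least one non-empty array
    lo, hi = 0, max(array[-1] for array in nums if array)
    while lo < hi:
        mid = (lo + hi) // 2
        if count_less_equal(nums, mid) < k:
            lo = mid + 1   # fewer than k elements <= mid: answer is larger
        else:
            hi = mid       # at least k elements <= mid: answer is mid or smaller
    return lo


nums = [[1, 4, 7], [2, 5], [3, 6, 8, 9]]
print(count_less_equal(nums, 4))  # 2 + 1 + 1 = 4
print(kth_smallest(nums, 5))      # 5, the median of the nine elements
```

For these arrays the search visits mid = 4, 7, 6, 5 and the counts are 4, 7, 6, 5, so the interval shrinks to the value 5.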
